Start over new template handling (to avoid overcomplicating)

This commit is contained in:
oobabooga 2025-08-05 16:58:45 -07:00
parent 80f6abb07e
commit 20adc3c967

View file

@ -86,134 +86,6 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
# Template Handler Classes
class TemplateHandler:
"""Base class for handling different template types"""
def __init__(self, template_str):
self.template_str = template_str
def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
"""Get prefix/suffix for generation"""
return "", ""
def get_stopping_strings(self, renderer):
"""Get stopping strings for this template type"""
return []
def modify_for_continue(self, prompt, renderer, impersonate=False):
"""Modify prompt for continue mode"""
return prompt
class LinearTemplateHandler(TemplateHandler):
"""Handles traditional linear templates"""
def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
# This is the original, complex logic for deriving prefix/suffix for old templates.
if impersonate:
messages = [
{"role": "user", "content": "<<|user-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"},
]
else:
messages = [
{"role": "assistant", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|user-message-2|>>"},
]
prompt = renderer(messages=messages)
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
suffix = prompt.split("<<|user-message-2|>>")[1]
if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix:
start_index = suffix_plus_prefix.rindex('<|start|>')
prefix = suffix_plus_prefix[start_index:]
else:
prefix = suffix_plus_prefix[len(suffix):]
if strip_trailing_spaces:
prefix = prefix.rstrip(' ')
return prefix, suffix
def get_stopping_strings(self, renderer):
# This is the original, correct logic for dynamically creating stopping strings for linear templates.
prefix_bot, suffix_bot = self.get_generation_prefix_suffix(renderer, impersonate=False)
prefix_user, suffix_user = self.get_generation_prefix_suffix(renderer, impersonate=True)
stopping_strings = [
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
# Attempt to find a single EOT token to use as a stop string
for item in stopping_strings:
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
break
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
break
return stopping_strings
def modify_for_continue(self, prompt, renderer, impersonate=False):
suffix = self.get_generation_prefix_suffix(renderer, impersonate)[1]
if len(suffix) > 0:
return prompt[:-len(suffix)]
return prompt
class ChannelTemplateHandler(TemplateHandler):
"""Handles channel-based templates"""
def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
"""
Gets the string to add to the prompt to start a new turn.
"""
if impersonate:
# For impersonate mode, we need the prefix for a user's turn.
prefix = "<|start|>user<|message|>"
else:
# For a normal reply, we need the prefix for the assistant's turn.
prefix = "<|start|>assistant"
if strip_trailing_spaces:
prefix = prefix.rstrip(' ')
# The suffix is not needed for this template type's generation logic.
return prefix, ""
def get_stopping_strings(self, renderer):
# Use specific tokens that unambiguously signal the end of a turn
# or the start of a different character's turn.
return [
'<|return|>',
'<|start|>user',
'<|start|>developer',
'<|call|>',
]
def modify_for_continue(self, prompt, renderer, impersonate=False):
suffix = '<|return|>'
if prompt.endswith(suffix):
return prompt[:-len(suffix)]
return prompt
def create_template_handler(template_str):
"""Factory function to create appropriate handler"""
if '<|channel|>' in template_str:
return ChannelTemplateHandler(template_str)
return LinearTemplateHandler(template_str)
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True): def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
''' '''
Given a Jinja template, reverse-engineers the prefix and the suffix for Given a Jinja template, reverse-engineers the prefix and the suffix for
@ -236,14 +108,7 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
suffix = prompt.split("<<|user-message-2|>>")[1] suffix = prompt.split("<<|user-message-2|>>")[1]
prefix = suffix_plus_prefix[len(suffix):]
# Remove the message suffix. The first case handles the GPT-OSS model
# in a way that is likely to not interfere with previous models.
if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix:
start_index = suffix_plus_prefix.rindex('<|start|>')
prefix = suffix_plus_prefix[start_index:]
else:
prefix = suffix_plus_prefix[len(suffix):]
if strip_trailing_spaces: if strip_trailing_spaces:
prefix = prefix.rstrip(' ') prefix = prefix.rstrip(' ')
@ -399,10 +264,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
# Create template handler based on current template
template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
handler = create_template_handler(template_str)
def make_prompt(messages): def make_prompt(messages):
if state['mode'] == 'chat-instruct' and _continue: if state['mode'] == 'chat-instruct' and _continue:
prompt = renderer(messages=messages[:-1]) prompt = renderer(messages=messages[:-1])
@ -420,10 +281,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
command = replace_character_names(command, state['name1'], state['name2']) command = replace_character_names(command, state['name1'], state['name2'])
if _continue: if _continue:
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0] prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"] prefix += messages[-1]["content"]
else: else:
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0] prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if not impersonate: if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state) prefix = apply_extensions('bot_prefix', prefix, state)
@ -431,14 +292,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
outer_messages.append({"role": "assistant", "content": prefix}) outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruct_renderer(messages=outer_messages) prompt = instruct_renderer(messages=outer_messages)
suffix = handler.get_generation_prefix_suffix(instruct_renderer, impersonate=False)[1] suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
if len(suffix) > 0: if len(suffix) > 0:
prompt = prompt[:-len(suffix)] prompt = prompt[:-len(suffix)]
else: else:
if _continue: if _continue:
prompt = handler.modify_for_continue(prompt, renderer, impersonate) suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else: else:
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0] prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if state['mode'] == 'chat' and not impersonate: if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state) prefix = apply_extensions('bot_prefix', prefix, state)
@ -564,16 +427,31 @@ def get_stopping_strings(state):
if state['mode'] in ['instruct', 'chat-instruct']: if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str']) template = jinja_env.from_string(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False) renderer = partial(template.render, add_generation_prompt=False)
renderers.append((renderer, state['instruction_template_str'])) renderers.append(renderer)
if state['mode'] in ['chat', 'chat-instruct']: if state['mode'] in ['chat', 'chat-instruct']:
template = jinja_env.from_string(state['chat_template_str']) template = jinja_env.from_string(state['chat_template_str'])
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']) renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append((renderer, state['chat_template_str'])) renderers.append(renderer)
for renderer, template_str in renderers: for renderer in renderers:
handler = create_template_handler(template_str) prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
stopping_strings += handler.get_stopping_strings(renderer) prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
stopping_strings += [
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
# Try to find the EOT token
for item in stopping_strings.copy():
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings') stopping_strings += state.pop('stopping_strings')
@ -772,16 +650,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
output = apply_extensions('history', output) output = apply_extensions('history', output)
state = apply_extensions('state', state) state = apply_extensions('state', state)
# Automatically set skip_special_tokens to False for channel-based templates
if state['mode'] in ['instruct', 'chat-instruct']:
template_str = state['instruction_template_str']
else: # chat mode
template_str = state['chat_template_str']
handler = create_template_handler(template_str)
if isinstance(handler, ChannelTemplateHandler):
state['skip_special_tokens'] = False
# Let the jinja2 template handle the BOS token # Let the jinja2 template handle the BOS token
if state['mode'] in ['instruct', 'chat-instruct']: if state['mode'] in ['instruct', 'chat-instruct']:
state['add_bos_token'] = False state['add_bos_token'] = False
@ -941,10 +809,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
def impersonate_wrapper(textbox, state): def impersonate_wrapper(textbox, state):
# Check template support first
template_str = state['chat_template_str']
handler = create_template_handler(template_str)
text = textbox['text'] text = textbox['text']
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])