Handle templates with channels separately

This commit is contained in:
oobabooga 2025-08-05 12:38:06 -07:00
parent 9f28f53cfc
commit 178c3e75cc

View file

@ -86,6 +86,134 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
# Template Handler Classes
class TemplateHandler:
    """
    Common interface for template-specific prompt handling.

    The defaults defined here are all no-ops: an empty prefix/suffix pair,
    no stopping strings, an untouched prompt in continue mode, and no
    impersonate support. Subclasses override whatever their template
    family needs.
    """

    def __init__(self, template_str):
        # Raw template text this handler was created for.
        self.template_str = template_str

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """
        Return a ``(prefix, suffix)`` tuple for generation.

        The base implementation ignores every argument and returns two
        empty strings.
        """
        return ("", "")

    def get_stopping_strings(self, renderer):
        """
        Return the list of stopping strings for this template type.

        The base implementation returns a new empty list.
        """
        return list()

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """
        Return the prompt to use in continue mode.

        The base implementation hands ``prompt`` back unchanged.
        """
        return prompt

    def supports_impersonate(self):
        """Tell whether impersonate mode is available (``False`` by default)."""
        return False
class LinearTemplateHandler(TemplateHandler):
    """
    Handler for traditional linear templates.

    The prefix and suffix surrounding a message are recovered by rendering
    two placeholder messages and slicing the rendered text around them.
    """

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """
        Derive the text that precedes and follows a message of one role.

        Two placeholder messages are rendered with the same role: "user"
        when ``impersonate`` is true, "assistant" otherwise. The text
        between the placeholders is "suffix + prefix"; the text after the
        second placeholder is the suffix.

        Returns a ``(prefix, suffix)`` tuple. When ``strip_trailing_spaces``
        is true, trailing space characters are removed from the prefix.
        """
        first_marker = "<<|user-message-1|>>"
        second_marker = "<<|user-message-2|>>"
        role = "user" if impersonate else "assistant"

        rendered = renderer(messages=[
            {"role": role, "content": first_marker},
            {"role": role, "content": second_marker},
        ])

        between = rendered.split(first_marker)[1].split(second_marker)[0]
        suffix = rendered.split(second_marker)[1]

        role_openers = ('<|start|>user', '<|start|>assistant')
        if any(opener in between for opener in role_openers):
            # The prefix begins at the last '<|start|>' found between the
            # two placeholders.
            prefix = between[between.rindex('<|start|>'):]
        else:
            # Otherwise drop a suffix-length head to leave just the prefix.
            prefix = between[len(suffix):]

        if strip_trailing_spaces:
            prefix = prefix.rstrip(' ')

        return prefix, suffix

    def get_stopping_strings(self, renderer):
        """
        Build stopping strings from the bot and user prefix/suffix pairs.

        The result holds the four suffix + prefix combinations, in the order
        user+bot, user+user, bot+bot, bot+user. One extra entry is appended
        when a combination, once stripped, starts with a ``<...>`` or
        ``[...]`` token: that leading token, taken from the first
        combination that has one.
        """
        prefix_bot, suffix_bot = self.get_generation_prefix_suffix(renderer, impersonate=False)
        prefix_user, suffix_user = self.get_generation_prefix_suffix(renderer, impersonate=True)

        stopping_strings = [
            tail + head
            for tail in (suffix_user, suffix_bot)
            for head in (prefix_bot, prefix_user)
        ]

        # Attempt to find a single EOT token to use as a stop string. Only
        # the first match is kept.
        bracket_pairs = (("<", ">"), ("[", "]"))
        leading_tokens = (
            text.split(closer)[0] + closer
            for text in [combo.strip() for combo in stopping_strings]
            for opener, closer in bracket_pairs
            if text.startswith(opener) and closer in text
        )
        eot_token = next(leading_tokens, None)
        if eot_token is not None:
            stopping_strings.append(eot_token)

        return stopping_strings

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """
        Cut a suffix-length tail off ``prompt`` for continue mode.

        The prompt is returned unchanged when the suffix is empty.
        """
        _, suffix = self.get_generation_prefix_suffix(renderer, impersonate)
        if not suffix:
            return prompt
        return prompt[:-len(suffix)]

    def supports_impersonate(self):
        """Linear templates support impersonate mode."""
        return True
class ChannelTemplateHandler(TemplateHandler):
    """Handler for channel-based templates."""

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """
        Return the string that starts a new generation, plus an empty suffix.

        The same single-message conversation is rendered with and without
        ``add_generation_prompt``. Whatever the enabled render has beyond
        the length of the disabled one is the generation prompt.
        ``impersonate`` is accepted but not used here.
        """
        probe = [{'role': 'user', 'content': '...'}]
        bare = renderer(messages=probe, add_generation_prompt=False)
        primed = renderer(messages=probe, add_generation_prompt=True)

        tail = primed[len(bare):]
        return (tail.rstrip(' ') if strip_trailing_spaces else tail), ""

    def get_stopping_strings(self, renderer):
        """Return a fixed list of stopping strings; ``renderer`` is unused."""
        markers = (
            '<|return|>',
            '<|start|>user',
            '<|start|>developer',
            '<|call|>',
        )
        return list(markers)

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """Continue mode leaves the prompt as it is."""
        return prompt

    def supports_impersonate(self):
        """Channel-based templates do not support impersonate mode."""
        return False
def create_template_handler(template_str):
    """
    Pick and instantiate the handler matching a template.

    A template containing the '<|channel|>' marker gets a
    ChannelTemplateHandler; any other template gets a
    LinearTemplateHandler.
    """
    handler_cls = ChannelTemplateHandler if '<|channel|>' in template_str else LinearTemplateHandler
    return handler_cls(template_str)
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
'''
Given a Jinja template, reverse-engineers the prefix and the suffix for
@ -270,6 +398,15 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "user", "content": user_input})
# Create template handler based on current template
template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
handler = create_template_handler(template_str)
# Check impersonate support early
if impersonate and not handler.supports_impersonate():
logger.warning("Impersonate not supported for channel-based templates")
return ""
def make_prompt(messages):
if state['mode'] == 'chat-instruct' and _continue:
prompt = renderer(messages=messages[:-1])
@ -287,10 +424,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
command = replace_character_names(command, state['name1'], state['name2'])
if _continue:
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
@ -298,16 +435,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruct_renderer(messages=outer_messages)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
suffix = handler.get_generation_prefix_suffix(instruct_renderer, impersonate=False)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
prompt = handler.modify_for_continue(prompt, renderer, impersonate)
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
@ -433,31 +568,16 @@ def get_stopping_strings(state):
if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False)
renderers.append(renderer)
renderers.append((renderer, state['instruction_template_str']))
if state['mode'] in ['chat', 'chat-instruct']:
template = jinja_env.from_string(state['chat_template_str'])
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append(renderer)
renderers.append((renderer, state['chat_template_str']))
for renderer in renderers:
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
stopping_strings += [
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
# Try to find the EOT token
for item in stopping_strings.copy():
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
for renderer, template_str in renderers:
handler = create_template_handler(template_str)
stopping_strings += handler.get_stopping_strings(renderer)
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings')
@ -809,6 +929,16 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
def impersonate_wrapper(textbox, state):
# Check template support first
template_str = state['chat_template_str']
handler = create_template_handler(template_str)
if not handler.supports_impersonate():
logger.warning("Impersonate not supported for channel-based templates")
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
yield textbox, static_output
return
text = textbox['text']
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])