From 178c3e75cca827657a018a64ae3d7945d9e25231 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 5 Aug 2025 12:38:06 -0700
Subject: [PATCH] Handle templates with channels separately

---
 modules/chat.py | 184 +++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 157 insertions(+), 27 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index c10d91a7..f929f653 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -86,6 +86,134 @@ yaml.add_representer(str, str_presenter)
 yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
 
 
+# Template Handler Classes
+class TemplateHandler:
+    """Base class for handling different template types"""
+
+    def __init__(self, template_str):
+        self.template_str = template_str
+
+    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
+        """Get prefix/suffix for generation"""
+        return "", ""
+
+    def get_stopping_strings(self, renderer):
+        """Get stopping strings for this template type"""
+        return []
+
+    def modify_for_continue(self, prompt, renderer, impersonate=False):
+        """Modify prompt for continue mode"""
+        return prompt
+
+    def supports_impersonate(self):
+        """Whether impersonate mode is supported"""
+        return False
+
+
+class LinearTemplateHandler(TemplateHandler):
+    """Handles traditional linear templates"""
+
+    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
+        # This is the original, complex logic for deriving prefix/suffix for old templates.
+        if impersonate:
+            messages = [
+                {"role": "user", "content": "<<|user-message-1|>>"},
+                {"role": "user", "content": "<<|user-message-2|>>"},
+            ]
+        else:
+            messages = [
+                {"role": "assistant", "content": "<<|user-message-1|>>"},
+                {"role": "assistant", "content": "<<|user-message-2|>>"},
+            ]
+
+        prompt = renderer(messages=messages)
+        suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
+        suffix = prompt.split("<<|user-message-2|>>")[1]
+
+        if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix:
+            start_index = suffix_plus_prefix.rindex('<|start|>')
+            prefix = suffix_plus_prefix[start_index:]
+        else:
+            prefix = suffix_plus_prefix[len(suffix):]
+
+        if strip_trailing_spaces:
+            prefix = prefix.rstrip(' ')
+
+        return prefix, suffix
+
+    def get_stopping_strings(self, renderer):
+        # This is the original, correct logic for dynamically creating stopping strings for linear templates.
+        prefix_bot, suffix_bot = self.get_generation_prefix_suffix(renderer, impersonate=False)
+        prefix_user, suffix_user = self.get_generation_prefix_suffix(renderer, impersonate=True)
+
+        stopping_strings = [
+            suffix_user + prefix_bot,
+            suffix_user + prefix_user,
+            suffix_bot + prefix_bot,
+            suffix_bot + prefix_user,
+        ]
+
+        # Attempt to find a single EOT token to use as a stop string
+        for item in stopping_strings:
+            item = item.strip()
+            if item.startswith("<") and ">" in item:
+                stopping_strings.append(item.split(">")[0] + ">")
+                break
+            elif item.startswith("[") and "]" in item:
+                stopping_strings.append(item.split("]")[0] + "]")
+                break
+
+        return stopping_strings
+
+    def modify_for_continue(self, prompt, renderer, impersonate=False):
+        suffix = self.get_generation_prefix_suffix(renderer, impersonate)[1]
+        if len(suffix) > 0:
+            return prompt[:-len(suffix)]
+        return prompt
+
+    def supports_impersonate(self):
+        return True
+
+
+class ChannelTemplateHandler(TemplateHandler):
+    """Handles channel-based templates"""
+
+    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
+        """
+        Gets the string to add to the prompt to start a new generation.
+        """
+        dummy_message = [{'role': 'user', 'content': '...'}]
+        prompt_without_gen = renderer(messages=dummy_message, add_generation_prompt=False)
+        prompt_with_gen = renderer(messages=dummy_message, add_generation_prompt=True)
+        generation_prompt = prompt_with_gen[len(prompt_without_gen):]
+
+        if strip_trailing_spaces:
+            generation_prompt = generation_prompt.rstrip(' ')
+
+        return generation_prompt, ""
+
+    def get_stopping_strings(self, renderer):
+        return [
+            '<|return|>',
+            '<|start|>user',
+            '<|start|>developer',
+            '<|call|>'
+        ]
+
+    def modify_for_continue(self, prompt, renderer, impersonate=False):
+        return prompt
+
+    def supports_impersonate(self):
+        return False
+
+
+def create_template_handler(template_str):
+    """Factory function to create appropriate handler"""
+    if '<|channel|>' in template_str:
+        return ChannelTemplateHandler(template_str)
+    return LinearTemplateHandler(template_str)
+
+
 def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
     '''
     Given a Jinja template, reverse-engineers the prefix and the suffix for
@@ -270,6 +398,15 @@ def generate_chat_prompt(user_input, state, **kwargs):
 
         messages.append({"role": "user", "content": user_input})
 
+    # Create template handler based on current template
+    template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
+    handler = create_template_handler(template_str)
+
+    # Check impersonate support early
+    if impersonate and not handler.supports_impersonate():
+        logger.warning("Impersonate not supported for channel-based templates")
+        return ""
+
     def make_prompt(messages):
         if state['mode'] == 'chat-instruct' and _continue:
             prompt = renderer(messages=messages[:-1])
@@ -287,10 +424,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
             command = replace_character_names(command, state['name1'], state['name2'])
 
             if _continue:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
+                prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
                 prefix += messages[-1]["content"]
             else:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
+                prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
                 if not impersonate:
                     prefix = apply_extensions('bot_prefix', prefix, state)
@@ -298,16 +435,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
             outer_messages.append({"role": "assistant", "content": prefix})
 
             prompt = instruct_renderer(messages=outer_messages)
-            suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
+            suffix = handler.get_generation_prefix_suffix(instruct_renderer, impersonate=False)[1]
             if len(suffix) > 0:
                 prompt = prompt[:-len(suffix)]
         else:
             if _continue:
-                suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
-                if len(suffix) > 0:
-                    prompt = prompt[:-len(suffix)]
+                prompt = handler.modify_for_continue(prompt, renderer, impersonate)
             else:
-                prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
+                prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
                 if state['mode'] == 'chat' and not impersonate:
                     prefix = apply_extensions('bot_prefix', prefix, state)
 
@@ -433,31 +568,16 @@ def get_stopping_strings(state):
     if state['mode'] in ['instruct', 'chat-instruct']:
         template = jinja_env.from_string(state['instruction_template_str'])
         renderer = partial(template.render, add_generation_prompt=False)
-        renderers.append(renderer)
+        renderers.append((renderer, state['instruction_template_str']))
 
     if state['mode'] in ['chat', 'chat-instruct']:
         template = jinja_env.from_string(state['chat_template_str'])
         renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
-        renderers.append(renderer)
+        renderers.append((renderer, state['chat_template_str']))
 
-    for renderer in renderers:
-        prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
-        prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
-
-        stopping_strings += [
-            suffix_user + prefix_bot,
-            suffix_user + prefix_user,
-            suffix_bot + prefix_bot,
-            suffix_bot + prefix_user,
-        ]
-
-        # Try to find the EOT token
-        for item in stopping_strings.copy():
-            item = item.strip()
-            if item.startswith("<") and ">" in item:
-                stopping_strings.append(item.split(">")[0] + ">")
-            elif item.startswith("[") and "]" in item:
-                stopping_strings.append(item.split("]")[0] + "]")
+    for renderer, template_str in renderers:
+        handler = create_template_handler(template_str)
+        stopping_strings += handler.get_stopping_strings(renderer)
 
     if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
         stopping_strings += state.pop('stopping_strings')
@@ -809,6 +929,16 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
 
 
 def impersonate_wrapper(textbox, state):
+    # Check template support first
+    template_str = state['chat_template_str']
+    handler = create_template_handler(template_str)
+
+    if not handler.supports_impersonate():
+        logger.warning("Impersonate not supported for channel-based templates")
+        static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
+        yield textbox, static_output
+        return
+
     text = textbox['text']
     static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
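
Note (not part of the patch): a minimal sketch of how the new dispatch is expected to behave. It assumes modules/chat.py from this branch is importable (which pulls in the webui's other dependencies), and it uses an invented toy channel-style template purely for illustration, not the real GPT-OSS chat template.

# Hypothetical usage sketch; the toy template below is invented for illustration only.
from functools import partial

from jinja2.sandbox import ImmutableSandboxedEnvironment

from modules.chat import create_template_handler  # requires the webui environment

# A minimal channel-style template: it contains '<|channel|>', so the factory
# is expected to return a ChannelTemplateHandler rather than a LinearTemplateHandler.
toy_template = (
    "{%- for m in messages -%}"
    "<|start|>{{ m['role'] }}<|channel|>final<|message|>{{ m['content'] }}<|end|>"
    "{%- endfor -%}"
    "{%- if add_generation_prompt -%}<|start|>assistant{%- endif -%}"
)

handler = create_template_handler(toy_template)
template = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True).from_string(toy_template)
renderer = partial(template.render)

# The channel handler derives the generation prefix by rendering a dummy message
# with and without add_generation_prompt and taking the difference.
prefix, suffix = handler.get_generation_prefix_suffix(renderer)
print(repr(prefix))                            # '<|start|>assistant'
print(repr(suffix))                            # ''
print(handler.get_stopping_strings(renderer))  # ['<|return|>', '<|start|>user', '<|start|>developer', '<|call|>']
print(handler.supports_impersonate())          # False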