From 20adc3c96737e35b96f6b1d557a63b1d2c75a825 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 5 Aug 2025 16:58:45 -0700 Subject: [PATCH] Start over new template handling (to avoid overcomplicating) --- modules/chat.py | 192 +++++++----------------------------------------- 1 file changed, 28 insertions(+), 164 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index b23340aa..82760cc8 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -86,134 +86,6 @@ yaml.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter) -# Template Handler Classes -class TemplateHandler: - """Base class for handling different template types""" - - def __init__(self, template_str): - self.template_str = template_str - - def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True): - """Get prefix/suffix for generation""" - return "", "" - - def get_stopping_strings(self, renderer): - """Get stopping strings for this template type""" - return [] - - def modify_for_continue(self, prompt, renderer, impersonate=False): - """Modify prompt for continue mode""" - return prompt - - -class LinearTemplateHandler(TemplateHandler): - """Handles traditional linear templates""" - - def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True): - # This is the original, complex logic for deriving prefix/suffix for old templates. - if impersonate: - messages = [ - {"role": "user", "content": "<<|user-message-1|>>"}, - {"role": "user", "content": "<<|user-message-2|>>"}, - ] - else: - messages = [ - {"role": "assistant", "content": "<<|user-message-1|>>"}, - {"role": "assistant", "content": "<<|user-message-2|>>"}, - ] - - prompt = renderer(messages=messages) - suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] - suffix = prompt.split("<<|user-message-2|>>")[1] - - if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix: - start_index = suffix_plus_prefix.rindex('<|start|>') - prefix = suffix_plus_prefix[start_index:] - else: - prefix = suffix_plus_prefix[len(suffix):] - - if strip_trailing_spaces: - prefix = prefix.rstrip(' ') - - return prefix, suffix - - def get_stopping_strings(self, renderer): - # This is the original, correct logic for dynamically creating stopping strings for linear templates. - prefix_bot, suffix_bot = self.get_generation_prefix_suffix(renderer, impersonate=False) - prefix_user, suffix_user = self.get_generation_prefix_suffix(renderer, impersonate=True) - - stopping_strings = [ - suffix_user + prefix_bot, - suffix_user + prefix_user, - suffix_bot + prefix_bot, - suffix_bot + prefix_user, - ] - - # Attempt to find a single EOT token to use as a stop string - for item in stopping_strings: - item = item.strip() - if item.startswith("<") and ">" in item: - stopping_strings.append(item.split(">")[0] + ">") - break - elif item.startswith("[") and "]" in item: - stopping_strings.append(item.split("]")[0] + "]") - break - - return stopping_strings - - def modify_for_continue(self, prompt, renderer, impersonate=False): - suffix = self.get_generation_prefix_suffix(renderer, impersonate)[1] - if len(suffix) > 0: - return prompt[:-len(suffix)] - - return prompt - - -class ChannelTemplateHandler(TemplateHandler): - """Handles channel-based templates""" - - def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True): - """ - Gets the string to add to the prompt to start a new turn. - """ - if impersonate: - # For impersonate mode, we need the prefix for a user's turn. - prefix = "<|start|>user<|message|>" - else: - # For a normal reply, we need the prefix for the assistant's turn. - prefix = "<|start|>assistant" - - if strip_trailing_spaces: - prefix = prefix.rstrip(' ') - - # The suffix is not needed for this template type's generation logic. - return prefix, "" - - def get_stopping_strings(self, renderer): - # Use specific tokens that unambiguously signal the end of a turn - # or the start of a different character's turn. - return [ - '<|return|>', - '<|start|>user', - '<|start|>developer', - '<|call|>', - ] - - def modify_for_continue(self, prompt, renderer, impersonate=False): - suffix = '<|return|>' - if prompt.endswith(suffix): - return prompt[:-len(suffix)] - - return prompt - - -def create_template_handler(template_str): - """Factory function to create appropriate handler""" - if '<|channel|>' in template_str: - return ChannelTemplateHandler(template_str) - return LinearTemplateHandler(template_str) - - def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True): ''' Given a Jinja template, reverse-engineers the prefix and the suffix for @@ -236,14 +108,7 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] suffix = prompt.split("<<|user-message-2|>>")[1] - - # Remove the message suffix. The first case handles the GPT-OSS model - # in a way that is likely to not interfere with previous models. - if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix: - start_index = suffix_plus_prefix.rindex('<|start|>') - prefix = suffix_plus_prefix[start_index:] - else: - prefix = suffix_plus_prefix[len(suffix):] + prefix = suffix_plus_prefix[len(suffix):] if strip_trailing_spaces: prefix = prefix.rstrip(' ') @@ -399,10 +264,6 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "user", "content": user_input}) - # Create template handler based on current template - template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str - handler = create_template_handler(template_str) - def make_prompt(messages): if state['mode'] == 'chat-instruct' and _continue: prompt = renderer(messages=messages[:-1]) @@ -420,10 +281,10 @@ def generate_chat_prompt(user_input, state, **kwargs): command = replace_character_names(command, state['name1'], state['name2']) if _continue: - prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0] + prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0] prefix += messages[-1]["content"] else: - prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0] + prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] if not impersonate: prefix = apply_extensions('bot_prefix', prefix, state) @@ -431,14 +292,16 @@ def generate_chat_prompt(user_input, state, **kwargs): outer_messages.append({"role": "assistant", "content": prefix}) prompt = instruct_renderer(messages=outer_messages) - suffix = handler.get_generation_prefix_suffix(instruct_renderer, impersonate=False)[1] + suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1] if len(suffix) > 0: prompt = prompt[:-len(suffix)] else: if _continue: - prompt = handler.modify_for_continue(prompt, renderer, impersonate) + suffix = get_generation_prompt(renderer, impersonate=impersonate)[1] + if len(suffix) > 0: + prompt = prompt[:-len(suffix)] else: - prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0] + prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] if state['mode'] == 'chat' and not impersonate: prefix = apply_extensions('bot_prefix', prefix, state) @@ -564,16 +427,31 @@ def get_stopping_strings(state): if state['mode'] in ['instruct', 'chat-instruct']: template = jinja_env.from_string(state['instruction_template_str']) renderer = partial(template.render, add_generation_prompt=False) - renderers.append((renderer, state['instruction_template_str'])) + renderers.append(renderer) if state['mode'] in ['chat', 'chat-instruct']: template = jinja_env.from_string(state['chat_template_str']) renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']) - renderers.append((renderer, state['chat_template_str'])) + renderers.append(renderer) - for renderer, template_str in renderers: - handler = create_template_handler(template_str) - stopping_strings += handler.get_stopping_strings(renderer) + for renderer in renderers: + prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False) + prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True) + + stopping_strings += [ + suffix_user + prefix_bot, + suffix_user + prefix_user, + suffix_bot + prefix_bot, + suffix_bot + prefix_user, + ] + + # Try to find the EOT token + for item in stopping_strings.copy(): + item = item.strip() + if item.startswith("<") and ">" in item: + stopping_strings.append(item.split(">")[0] + ">") + elif item.startswith("[") and "]" in item: + stopping_strings.append(item.split("]")[0] + "]") if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') @@ -772,16 +650,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output = apply_extensions('history', output) state = apply_extensions('state', state) - # Automatically set skip_special_tokens to False for channel-based templates - if state['mode'] in ['instruct', 'chat-instruct']: - template_str = state['instruction_template_str'] - else: # chat mode - template_str = state['chat_template_str'] - - handler = create_template_handler(template_str) - if isinstance(handler, ChannelTemplateHandler): - state['skip_special_tokens'] = False - # Let the jinja2 template handle the BOS token if state['mode'] in ['instruct', 'chat-instruct']: state['add_bos_token'] = False @@ -941,10 +809,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess def impersonate_wrapper(textbox, state): - # Check template support first - template_str = state['chat_template_str'] - handler = create_template_handler(template_str) - text = textbox['text'] static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])