diff --git a/modules/chat.py b/modules/chat.py
index 46d24a6f..043908c9 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -105,10 +105,6 @@ class TemplateHandler:
         """Modify prompt for continue mode"""
         return prompt
 
-    def supports_impersonate(self):
-        """Whether impersonate mode is supported"""
-        return False
-
 
 class LinearTemplateHandler(TemplateHandler):
     """Handles traditional linear templates"""
@@ -171,41 +167,41 @@ class LinearTemplateHandler(TemplateHandler):
             return prompt[:-len(suffix)]
         return prompt
 
-    def supports_impersonate(self):
-        return True
-
 
 class ChannelTemplateHandler(TemplateHandler):
     """Handles channel-based templates"""
 
     def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
         """
-        Gets the string to add to the prompt to start a new generation.
+        Gets the string to add to the prompt to start a new turn.
         """
-        dummy_message = [{'role': 'user', 'content': '...'}]
-        prompt_without_gen = renderer(messages=dummy_message, add_generation_prompt=False)
-        prompt_with_gen = renderer(messages=dummy_message, add_generation_prompt=True)
-        generation_prompt = prompt_with_gen[len(prompt_without_gen):]
+        if impersonate:
+            # For impersonate mode, we need the prefix for a user's turn.
+            prefix = "<|start|>user<|message|>"
+        else:
+            # For a normal reply, we need the prefix for the assistant's turn.
+            prefix = "<|start|>assistant"
 
         if strip_trailing_spaces:
-            generation_prompt = generation_prompt.rstrip(' ')
+            prefix = prefix.rstrip(' ')
 
-        return generation_prompt, ""
+        # The suffix is not needed for this template type's generation logic.
+        return prefix, ""
 
     def get_stopping_strings(self, renderer):
+        # Use specific tokens that unambiguously signal the end of a turn
+        # or the start of a different character's turn.
         return [
             '<|return|>',
             '<|start|>user',
            '<|start|>developer',
-            '<|call|>'
+            '<|call|>',
        ]
 
     def modify_for_continue(self, prompt, renderer, impersonate=False):
+        # Channels don't need suffix stripping for the continue logic to work.
         return prompt
 
-    def supports_impersonate(self):
-        return False
-
 
 def create_template_handler(template_str):
     """Factory function to create appropriate handler"""
@@ -402,11 +398,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
     template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
     handler = create_template_handler(template_str)
 
-    # Check impersonate support early
-    if impersonate and not handler.supports_impersonate():
-        logger.warning("Impersonate not supported for channel-based templates")
-        return ""
-
     def make_prompt(messages):
         if state['mode'] == 'chat-instruct' and _continue:
             prompt = renderer(messages=messages[:-1])
@@ -943,12 +934,6 @@ def impersonate_wrapper(textbox, state):
     template_str = state['chat_template_str']
     handler = create_template_handler(template_str)
 
-    if not handler.supports_impersonate():
-        logger.warning("Impersonate not supported for channel-based templates")
-        static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
-        yield textbox, static_output
-        return
-
     text = textbox['text']
     static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])