Fix impersonate

oobabooga 2025-08-05 13:03:18 -07:00
parent ecd16d6bf9
commit 5c5a4dfc14


@@ -105,10 +105,6 @@ class TemplateHandler:
         """Modify prompt for continue mode"""
         return prompt
 
-    def supports_impersonate(self):
-        """Whether impersonate mode is supported"""
-        return False
-
 
 class LinearTemplateHandler(TemplateHandler):
     """Handles traditional linear templates"""
@@ -171,41 +167,41 @@ class LinearTemplateHandler(TemplateHandler):
             return prompt[:-len(suffix)]
         return prompt
 
-    def supports_impersonate(self):
-        return True
-
 
 class ChannelTemplateHandler(TemplateHandler):
     """Handles channel-based templates"""
 
     def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
         """
-        Gets the string to add to the prompt to start a new generation.
+        Gets the string to add to the prompt to start a new turn.
         """
-        dummy_message = [{'role': 'user', 'content': '...'}]
-        prompt_without_gen = renderer(messages=dummy_message, add_generation_prompt=False)
-        prompt_with_gen = renderer(messages=dummy_message, add_generation_prompt=True)
-        generation_prompt = prompt_with_gen[len(prompt_without_gen):]
+        if impersonate:
+            # For impersonate mode, we need the prefix for a user's turn.
+            prefix = "<|start|>user<|message|>"
+        else:
+            # For a normal reply, we need the prefix for the assistant's turn.
+            prefix = "<|start|>assistant"
 
         if strip_trailing_spaces:
-            generation_prompt = generation_prompt.rstrip(' ')
+            prefix = prefix.rstrip(' ')
 
-        return generation_prompt, ""
+        # The suffix is not needed for this template type's generation logic.
+        return prefix, ""
 
     def get_stopping_strings(self, renderer):
         # Use specific tokens that unambiguously signal the end of a turn
         # or the start of a different character's turn.
         return [
             '<|return|>',
             '<|start|>user',
             '<|start|>developer',
-            '<|call|>'
+            '<|call|>',
         ]
 
     def modify_for_continue(self, prompt, renderer, impersonate=False):
         # Channels don't need suffix stripping for the continue logic to work.
         return prompt
 
-    def supports_impersonate(self):
-        return False
-
 
 def create_template_handler(template_str):
     """Factory function to create appropriate handler"""
@@ -402,11 +398,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
     template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
     handler = create_template_handler(template_str)
 
-    # Check impersonate support early
-    if impersonate and not handler.supports_impersonate():
-        logger.warning("Impersonate not supported for channel-based templates")
-        return ""
-
     def make_prompt(messages):
         if state['mode'] == 'chat-instruct' and _continue:
            prompt = renderer(messages=messages[:-1])
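
The handler above comes from create_template_handler, the factory whose signature appears at the end of the previous hunk. Below is a hedged sketch of that dispatch, assuming the factory picks a handler by looking for harmony-style channel tokens in the template string; the detection rule and the helper name are assumptions for illustration, not taken from this commit.

    # Illustrative dispatch only; the real create_template_handler may differ.
    def pick_handler_name(template_str):
        # Returns a class name as a string so the sketch runs standalone.
        if '<|start|>' in template_str and '<|message|>' in template_str:
            return 'ChannelTemplateHandler'  # channel/harmony-style template
        return 'LinearTemplateHandler'       # traditional linear template

    print(pick_handler_name("{{ '<|start|>user<|message|>' }}"))  # ChannelTemplateHandler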
@@ -943,12 +934,6 @@ def impersonate_wrapper(textbox, state):
     template_str = state['chat_template_str']
     handler = create_template_handler(template_str)
 
-    if not handler.supports_impersonate():
-        logger.warning("Impersonate not supported for channel-based templates")
-        static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
-        yield textbox, static_output
-        return
-
     text = textbox['text']
     static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
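
With the early return gone, impersonate_wrapper proceeds for channel-based templates too, and the end of the generated turn is signalled by the handler's stopping strings rather than a capability check. Below is a minimal, self-contained sketch of how stop strings such as those returned by ChannelTemplateHandler.get_stopping_strings are commonly applied to generated text; the trimming helper and the sample string are illustrative assumptions, not code from this repository.

    # STOP_STRINGS copies get_stopping_strings() above; the rest is illustrative.
    STOP_STRINGS = ['<|return|>', '<|start|>user', '<|start|>developer', '<|call|>']

    def trim_at_first_stop(text, stop_strings=STOP_STRINGS):
        # Cut the text at the earliest occurrence of any stop string.
        cut = len(text)
        for s in stop_strings:
            idx = text.find(s)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut]

    print(trim_at_first_stop("Sure, let's try that.<|return|><|start|>assistant"))
    # -> "Sure, let's try that."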