mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-25 18:10:43 +01:00
Fix impersonate
This commit is contained in:
parent
ecd16d6bf9
commit
5c5a4dfc14
|
|
@ -105,10 +105,6 @@ class TemplateHandler:
|
|||
"""Modify prompt for continue mode"""
|
||||
return prompt
|
||||
|
||||
def supports_impersonate(self):
|
||||
"""Whether impersonate mode is supported"""
|
||||
return False
|
||||
|
||||
|
||||
class LinearTemplateHandler(TemplateHandler):
|
||||
"""Handles traditional linear templates"""
|
||||
|
|
@ -171,41 +167,41 @@ class LinearTemplateHandler(TemplateHandler):
|
|||
return prompt[:-len(suffix)]
|
||||
return prompt
|
||||
|
||||
def supports_impersonate(self):
|
||||
return True
|
||||
|
||||
|
||||
class ChannelTemplateHandler(TemplateHandler):
|
||||
"""Handles channel-based templates"""
|
||||
|
||||
def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
|
||||
"""
|
||||
Gets the string to add to the prompt to start a new generation.
|
||||
Gets the string to add to the prompt to start a new turn.
|
||||
"""
|
||||
dummy_message = [{'role': 'user', 'content': '...'}]
|
||||
prompt_without_gen = renderer(messages=dummy_message, add_generation_prompt=False)
|
||||
prompt_with_gen = renderer(messages=dummy_message, add_generation_prompt=True)
|
||||
generation_prompt = prompt_with_gen[len(prompt_without_gen):]
|
||||
if impersonate:
|
||||
# For impersonate mode, we need the prefix for a user's turn.
|
||||
prefix = "<|start|>user<|message|>"
|
||||
else:
|
||||
# For a normal reply, we need the prefix for the assistant's turn.
|
||||
prefix = "<|start|>assistant"
|
||||
|
||||
if strip_trailing_spaces:
|
||||
generation_prompt = generation_prompt.rstrip(' ')
|
||||
prefix = prefix.rstrip(' ')
|
||||
|
||||
return generation_prompt, ""
|
||||
# The suffix is not needed for this template type's generation logic.
|
||||
return prefix, ""
|
||||
|
||||
def get_stopping_strings(self, renderer):
|
||||
# Use specific tokens that unambiguously signal the end of a turn
|
||||
# or the start of a different character's turn.
|
||||
return [
|
||||
'<|return|>',
|
||||
'<|start|>user',
|
||||
'<|start|>developer',
|
||||
'<|call|>'
|
||||
'<|call|>',
|
||||
]
|
||||
|
||||
def modify_for_continue(self, prompt, renderer, impersonate=False):
|
||||
# Channels don't need suffix stripping for the continue logic to work.
|
||||
return prompt
|
||||
|
||||
def supports_impersonate(self):
|
||||
return False
|
||||
|
||||
|
||||
def create_template_handler(template_str):
|
||||
"""Factory function to create appropriate handler"""
|
||||
|
|
@ -402,11 +398,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
|
||||
handler = create_template_handler(template_str)
|
||||
|
||||
# Check impersonate support early
|
||||
if impersonate and not handler.supports_impersonate():
|
||||
logger.warning("Impersonate not supported for channel-based templates")
|
||||
return ""
|
||||
|
||||
def make_prompt(messages):
|
||||
if state['mode'] == 'chat-instruct' and _continue:
|
||||
prompt = renderer(messages=messages[:-1])
|
||||
|
|
@ -943,12 +934,6 @@ def impersonate_wrapper(textbox, state):
|
|||
template_str = state['chat_template_str']
|
||||
handler = create_template_handler(template_str)
|
||||
|
||||
if not handler.supports_impersonate():
|
||||
logger.warning("Impersonate not supported for channel-based templates")
|
||||
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||
yield textbox, static_output
|
||||
return
|
||||
|
||||
text = textbox['text']
|
||||
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue