mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Start over new template handling (to avoid overcomplicating)
This commit is contained in:
parent
80f6abb07e
commit
20adc3c967
192
modules/chat.py
192
modules/chat.py
|
|
@ -86,134 +86,6 @@ yaml.add_representer(str, str_presenter)
|
||||||
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
||||||
|
|
||||||
|
|
||||||
# Template Handler Classes
|
|
||||||
class TemplateHandler:
    """Common interface for template-specific prompt handling.

    The base implementation is a neutral fallback: it contributes no
    prefix or suffix, no stopping strings, and returns prompts untouched.
    Subclasses override the methods they need.
    """

    def __init__(self, template_str):
        # Keep the raw template source this handler was created for.
        self.template_str = template_str

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """Return the ``(prefix, suffix)`` pair used for generation.

        The base handler ignores all of its arguments and returns two
        empty strings.
        """
        empty = ""
        return empty, empty

    def get_stopping_strings(self, renderer):
        """Return the stopping strings for this template type.

        The base handler has none, so a fresh empty list is returned.
        """
        return list()

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """Return ``prompt`` adjusted for continue mode.

        The base handler performs no adjustment.
        """
        return prompt
|
|
||||||
|
|
||||||
|
|
||||||
class LinearTemplateHandler(TemplateHandler):
    """Handler for traditional linear templates.

    The prefix and suffix are not hard-coded: they are recovered by
    rendering two placeholder messages through the template and slicing
    the text found around the placeholders.
    """

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """Derive the ``(prefix, suffix)`` that the template puts around a message.

        Args:
            renderer: callable accepting ``messages=`` and returning the
                rendered prompt string.
            impersonate: when true the probe messages use the ``user``
                role, otherwise the ``assistant`` role.
            strip_trailing_spaces: when true, trailing spaces are removed
                from the returned prefix.

        Returns:
            A ``(prefix, suffix)`` tuple of strings.

        Raises:
            IndexError: if a placeholder does not appear in the rendered
                output.
        """
        first_marker = "<<|user-message-1|>>"
        second_marker = "<<|user-message-2|>>"
        role = "user" if impersonate else "assistant"

        # Two consecutive messages of the same role: the text between the
        # placeholders is "suffix + prefix", the text after the second
        # placeholder is the suffix alone.
        rendered = renderer(messages=[
            {"role": role, "content": first_marker},
            {"role": role, "content": second_marker},
        ])
        between = rendered.split(first_marker)[1].split(second_marker)[0]
        suffix = rendered.split(second_marker)[1]

        start_tags = ('<|start|>user', '<|start|>assistant')
        if any(tag in between for tag in start_tags):
            # The separator carries an explicit turn-start tag: the prefix
            # begins at the last '<|start|>' occurrence.
            prefix = between[between.rindex('<|start|>'):]
        else:
            # Otherwise drop a suffix-sized head from the separator.
            prefix = between[len(suffix):]

        if strip_trailing_spaces:
            prefix = prefix.rstrip(' ')

        return prefix, suffix

    def get_stopping_strings(self, renderer):
        """Build the stopping strings from the template's prefixes and suffixes.

        Returns the four suffix+prefix combinations (user/bot suffix
        followed by bot/user prefix) plus, when one can be found, a single
        leading ``<...>`` or ``[...]`` token taken from the first
        combination that starts with such a token.
        """
        prefix_bot, suffix_bot = self.get_generation_prefix_suffix(renderer, impersonate=False)
        prefix_user, suffix_user = self.get_generation_prefix_suffix(renderer, impersonate=True)

        combos = [
            tail + head
            for tail in (suffix_user, suffix_bot)
            for head in (prefix_bot, prefix_user)
        ]

        # Look for a single end-of-turn token; only the first match is used.
        eot_token = None
        for candidate in combos:
            text = candidate.strip()
            if text.startswith("<") and ">" in text:
                eot_token = text.split(">")[0] + ">"
            elif text.startswith("[") and "]" in text:
                eot_token = text.split("]")[0] + "]"

            if eot_token is not None:
                break

        if eot_token is not None:
            combos.append(eot_token)

        return combos

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """Cut a suffix-length tail off ``prompt`` for continue mode.

        The prompt is returned unchanged when the derived suffix is empty.
        """
        _, suffix = self.get_generation_prefix_suffix(renderer, impersonate)
        if not suffix:
            return prompt

        return prompt[:-len(suffix)]
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelTemplateHandler(TemplateHandler):
    """Handler for channel-based templates.

    This handler uses fixed ``<|start|>``-style tokens and does not
    render the template to derive them.
    """

    def get_generation_prefix_suffix(self, renderer, impersonate=False, strip_trailing_spaces=True):
        """Return the string that opens a new turn, paired with an empty suffix.

        Args:
            renderer: unused by this handler.
            impersonate: when true the prefix opens a user turn, otherwise
                an assistant turn.
            strip_trailing_spaces: when true, trailing spaces are removed
                from the prefix.

        Returns:
            ``(prefix, "")`` - no suffix is produced for this template type.
        """
        # User turn when impersonating, assistant turn for a normal reply.
        prefix = "<|start|>user<|message|>" if impersonate else "<|start|>assistant"

        if strip_trailing_spaces:
            prefix = prefix.rstrip(' ')

        return prefix, ""

    def get_stopping_strings(self, renderer):
        """Return the fixed tokens that end a turn or start another speaker's turn."""
        turn_boundaries = (
            '<|return|>',
            '<|start|>user',
            '<|start|>developer',
            '<|call|>',
        )
        return list(turn_boundaries)

    def modify_for_continue(self, prompt, renderer, impersonate=False):
        """Remove a trailing ``<|return|>`` from ``prompt`` if one is present."""
        end_marker = '<|return|>'
        if not prompt.endswith(end_marker):
            return prompt

        return prompt[:-len(end_marker)]
|
|
||||||
|
|
||||||
|
|
||||||
def create_template_handler(template_str):
    """Pick and instantiate the handler class matching ``template_str``.

    A template containing the ``<|channel|>`` token gets a
    ``ChannelTemplateHandler``; every other template gets a
    ``LinearTemplateHandler``.
    """
    is_channel_template = '<|channel|>' in template_str
    handler_cls = ChannelTemplateHandler if is_channel_template else LinearTemplateHandler
    return handler_cls(template_str)
|
|
||||||
|
|
||||||
|
|
||||||
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
|
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
|
||||||
'''
|
'''
|
||||||
Given a Jinja template, reverse-engineers the prefix and the suffix for
|
Given a Jinja template, reverse-engineers the prefix and the suffix for
|
||||||
|
|
@ -236,14 +108,7 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru
|
||||||
|
|
||||||
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
|
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
|
||||||
suffix = prompt.split("<<|user-message-2|>>")[1]
|
suffix = prompt.split("<<|user-message-2|>>")[1]
|
||||||
|
prefix = suffix_plus_prefix[len(suffix):]
|
||||||
# Remove the message suffix. The first case handles the GPT-OSS model
|
|
||||||
# in a way that is likely to not interfere with previous models.
|
|
||||||
if '<|start|>user' in suffix_plus_prefix or '<|start|>assistant' in suffix_plus_prefix:
|
|
||||||
start_index = suffix_plus_prefix.rindex('<|start|>')
|
|
||||||
prefix = suffix_plus_prefix[start_index:]
|
|
||||||
else:
|
|
||||||
prefix = suffix_plus_prefix[len(suffix):]
|
|
||||||
|
|
||||||
if strip_trailing_spaces:
|
if strip_trailing_spaces:
|
||||||
prefix = prefix.rstrip(' ')
|
prefix = prefix.rstrip(' ')
|
||||||
|
|
@ -399,10 +264,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
# Create template handler based on current template
|
|
||||||
template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
|
|
||||||
handler = create_template_handler(template_str)
|
|
||||||
|
|
||||||
def make_prompt(messages):
|
def make_prompt(messages):
|
||||||
if state['mode'] == 'chat-instruct' and _continue:
|
if state['mode'] == 'chat-instruct' and _continue:
|
||||||
prompt = renderer(messages=messages[:-1])
|
prompt = renderer(messages=messages[:-1])
|
||||||
|
|
@ -420,10 +281,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
command = replace_character_names(command, state['name1'], state['name2'])
|
command = replace_character_names(command, state['name1'], state['name2'])
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
|
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
|
||||||
prefix += messages[-1]["content"]
|
prefix += messages[-1]["content"]
|
||||||
else:
|
else:
|
||||||
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
|
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
||||||
if not impersonate:
|
if not impersonate:
|
||||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||||
|
|
||||||
|
|
@ -431,14 +292,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
outer_messages.append({"role": "assistant", "content": prefix})
|
outer_messages.append({"role": "assistant", "content": prefix})
|
||||||
|
|
||||||
prompt = instruct_renderer(messages=outer_messages)
|
prompt = instruct_renderer(messages=outer_messages)
|
||||||
suffix = handler.get_generation_prefix_suffix(instruct_renderer, impersonate=False)[1]
|
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
|
||||||
if len(suffix) > 0:
|
if len(suffix) > 0:
|
||||||
prompt = prompt[:-len(suffix)]
|
prompt = prompt[:-len(suffix)]
|
||||||
else:
|
else:
|
||||||
if _continue:
|
if _continue:
|
||||||
prompt = handler.modify_for_continue(prompt, renderer, impersonate)
|
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
|
||||||
|
if len(suffix) > 0:
|
||||||
|
prompt = prompt[:-len(suffix)]
|
||||||
else:
|
else:
|
||||||
prefix = handler.get_generation_prefix_suffix(renderer, impersonate=impersonate)[0]
|
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
||||||
if state['mode'] == 'chat' and not impersonate:
|
if state['mode'] == 'chat' and not impersonate:
|
||||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||||
|
|
||||||
|
|
@ -564,16 +427,31 @@ def get_stopping_strings(state):
|
||||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
if state['mode'] in ['instruct', 'chat-instruct']:
|
||||||
template = jinja_env.from_string(state['instruction_template_str'])
|
template = jinja_env.from_string(state['instruction_template_str'])
|
||||||
renderer = partial(template.render, add_generation_prompt=False)
|
renderer = partial(template.render, add_generation_prompt=False)
|
||||||
renderers.append((renderer, state['instruction_template_str']))
|
renderers.append(renderer)
|
||||||
|
|
||||||
if state['mode'] in ['chat', 'chat-instruct']:
|
if state['mode'] in ['chat', 'chat-instruct']:
|
||||||
template = jinja_env.from_string(state['chat_template_str'])
|
template = jinja_env.from_string(state['chat_template_str'])
|
||||||
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
|
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
|
||||||
renderers.append((renderer, state['chat_template_str']))
|
renderers.append(renderer)
|
||||||
|
|
||||||
for renderer, template_str in renderers:
|
for renderer in renderers:
|
||||||
handler = create_template_handler(template_str)
|
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
|
||||||
stopping_strings += handler.get_stopping_strings(renderer)
|
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
|
||||||
|
|
||||||
|
stopping_strings += [
|
||||||
|
suffix_user + prefix_bot,
|
||||||
|
suffix_user + prefix_user,
|
||||||
|
suffix_bot + prefix_bot,
|
||||||
|
suffix_bot + prefix_user,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Try to find the EOT token
|
||||||
|
for item in stopping_strings.copy():
|
||||||
|
item = item.strip()
|
||||||
|
if item.startswith("<") and ">" in item:
|
||||||
|
stopping_strings.append(item.split(">")[0] + ">")
|
||||||
|
elif item.startswith("[") and "]" in item:
|
||||||
|
stopping_strings.append(item.split("]")[0] + "]")
|
||||||
|
|
||||||
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
|
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
|
||||||
stopping_strings += state.pop('stopping_strings')
|
stopping_strings += state.pop('stopping_strings')
|
||||||
|
|
@ -772,16 +650,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
||||||
output = apply_extensions('history', output)
|
output = apply_extensions('history', output)
|
||||||
state = apply_extensions('state', state)
|
state = apply_extensions('state', state)
|
||||||
|
|
||||||
# Automatically set skip_special_tokens to False for channel-based templates
|
|
||||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
|
||||||
template_str = state['instruction_template_str']
|
|
||||||
else: # chat mode
|
|
||||||
template_str = state['chat_template_str']
|
|
||||||
|
|
||||||
handler = create_template_handler(template_str)
|
|
||||||
if isinstance(handler, ChannelTemplateHandler):
|
|
||||||
state['skip_special_tokens'] = False
|
|
||||||
|
|
||||||
# Let the jinja2 template handle the BOS token
|
# Let the jinja2 template handle the BOS token
|
||||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
if state['mode'] in ['instruct', 'chat-instruct']:
|
||||||
state['add_bos_token'] = False
|
state['add_bos_token'] = False
|
||||||
|
|
@ -941,10 +809,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
||||||
|
|
||||||
|
|
||||||
def impersonate_wrapper(textbox, state):
|
def impersonate_wrapper(textbox, state):
|
||||||
# Check template support first
|
|
||||||
template_str = state['chat_template_str']
|
|
||||||
handler = create_template_handler(template_str)
|
|
||||||
|
|
||||||
text = textbox['text']
|
text = textbox['text']
|
||||||
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue