mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-05 00:01:01 +01:00
chat.py code simplifications
This commit is contained in:
parent
d08800c359
commit
f919cdf881
128
modules/chat.py
128
modules/chat.py
|
|
@ -86,36 +86,6 @@ yaml.add_representer(str, str_presenter)
|
|||
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
    '''
    Recover the text a chat template wraps around a single message.

    Two placeholder messages with the same role are rendered through
    `renderer` ("user" if impersonate=True, otherwise "assistant"). The text
    after the second placeholder is taken as the message suffix; the text
    between the two placeholders is "suffix + prefix", so dropping
    len(suffix) characters from its start leaves the prefix.

    Returns a (prefix, suffix) tuple. When strip_trailing_spaces is True,
    trailing ' ' characters are removed from the prefix (other whitespace,
    such as newlines, is kept).
    '''

    first_marker = "<<|user-message-1|>>"
    second_marker = "<<|user-message-2|>>"
    role = "user" if impersonate else "assistant"

    rendered = renderer(messages=[
        {"role": role, "content": first_marker},
        {"role": role, "content": second_marker},
    ])

    # Text sitting between the two placeholders: end of message 1 followed by
    # the start of message 2.
    between_markers = rendered.split(first_marker)[1].split(second_marker)[0]

    # Whatever the template emits after the final message content.
    suffix = rendered.split(second_marker)[1]

    # The in-between text is assumed to begin with that same suffix.
    prefix = between_markers[len(suffix):]

    if strip_trailing_spaces:
        return prefix.rstrip(' '), suffix

    return prefix, suffix
|
||||
|
||||
|
||||
def get_thinking_suppression_string(template):
|
||||
"""
|
||||
Determines what string needs to be added to suppress thinking mode
|
||||
|
|
@ -341,26 +311,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
command = command.replace('<|prompt|>', prompt)
|
||||
command = replace_character_names(command, state['name1'], state['name2'])
|
||||
|
||||
if _continue:
|
||||
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
|
||||
prefix += messages[-1]["content"]
|
||||
else:
|
||||
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
||||
if not impersonate:
|
||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||
|
||||
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
|
||||
|
||||
outer_messages = []
|
||||
if state['custom_system_message'].strip() != '':
|
||||
outer_messages.append({"role": "system", "content": state['custom_system_message']})
|
||||
|
||||
outer_messages.append({"role": "user", "content": command})
|
||||
outer_messages.append({"role": "assistant", "content": prefix})
|
||||
|
||||
prompt = instruct_renderer(messages=outer_messages)
|
||||
if len(suffix) > 0:
|
||||
prompt = prompt[:-len(suffix)]
|
||||
prompt = instruct_renderer(
|
||||
messages=outer_messages,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
# Handle GPT-OSS as a special case when continuing
|
||||
# (otherwise the thinking block gets removed...)
|
||||
|
|
@ -375,29 +335,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
assistant_reply_so_far += f"<|channel|>final<|message|>{last_message_to_continue.get('content', '')}"
|
||||
prompt += assistant_reply_so_far
|
||||
else:
|
||||
prompt = renderer(messages=messages)
|
||||
if _continue:
|
||||
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
|
||||
if len(suffix) > 0:
|
||||
prompt = prompt[:-len(suffix)]
|
||||
else:
|
||||
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
||||
|
||||
# Handle GPT-OSS as a special case when not continuing
|
||||
if '<|channel|>final<|message|>' in state['instruction_template_str']:
|
||||
if prefix.endswith("<|channel|>final<|message|>"):
|
||||
prefix = prefix[:-len("<|channel|>final<|message|>")]
|
||||
|
||||
if impersonate:
|
||||
prefix += "<|message|>"
|
||||
|
||||
if state['mode'] == 'chat' and not impersonate:
|
||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||
|
||||
prompt += prefix
|
||||
|
||||
if state['mode'] == 'instruct' and 'enable_thinking' in state['instruction_template_str'] and not any((_continue, impersonate, state['enable_thinking'])):
|
||||
prompt += get_thinking_suppression_string(instruction_template)
|
||||
prompt = renderer(
|
||||
messages=messages,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
|
@ -523,24 +464,41 @@ def get_stopping_strings(state):
|
|||
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
|
||||
renderers.append(renderer)
|
||||
|
||||
for renderer in renderers:
|
||||
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
|
||||
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
|
||||
fake_messages = [
|
||||
{"role": "user", "content": "first user message"},
|
||||
{"role": "assistant", "content": "first assistant message"},
|
||||
{"role": "user", "content": "second user message"},
|
||||
{"role": "assistant", "content": "second assistant message"},
|
||||
]
|
||||
|
||||
stopping_strings += [
|
||||
suffix_user + prefix_bot,
|
||||
suffix_user + prefix_user,
|
||||
suffix_bot + prefix_bot,
|
||||
suffix_bot + prefix_user,
|
||||
stopping_strings = []
|
||||
for renderer in renderers:
|
||||
prompt = renderer(messages=fake_messages)
|
||||
|
||||
# Find positions of each message content
|
||||
first_user_end = prompt.find("first user message") + len("first user message")
|
||||
first_assistant_start = prompt.find("first assistant message")
|
||||
first_assistant_end = prompt.find("first assistant message") + len("first assistant message")
|
||||
second_user_start = prompt.find("second user message")
|
||||
second_assistant_end = prompt.find("second assistant message") + len("second assistant message")
|
||||
|
||||
# Extract pieces of text potentially containing unique stopping strings
|
||||
texts = [
|
||||
prompt[first_user_end:first_assistant_start],
|
||||
prompt[first_assistant_end:second_user_start],
|
||||
prompt[second_assistant_end:]
|
||||
]
|
||||
|
||||
# Try to find the EOT token
|
||||
for item in stopping_strings.copy():
|
||||
item = item.strip()
|
||||
if item.startswith("<") and ">" in item:
|
||||
stopping_strings.append(item.split(">")[0] + ">")
|
||||
elif item.startswith("[") and "]" in item:
|
||||
stopping_strings.append(item.split("]")[0] + "]")
|
||||
for text in texts:
|
||||
text = text.strip()
|
||||
if text.startswith("<") and ">" in text:
|
||||
stopping_strings.append(text.split(">")[0] + ">")
|
||||
elif text.startswith("[") and "]" in text:
|
||||
stopping_strings.append(text.split("]")[0] + "]")
|
||||
elif text.startswith("(") and ")" in text:
|
||||
stopping_strings.append(text.split(")")[0] + ")")
|
||||
elif text.startswith("{") and "}" in text:
|
||||
stopping_strings.append(text.split("}")[0] + "}")
|
||||
|
||||
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
|
||||
stopping_strings += state.pop('stopping_strings')
|
||||
|
|
@ -549,12 +507,6 @@ def get_stopping_strings(state):
|
|||
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
|
||||
result = list(set(result))
|
||||
|
||||
# Handle GPT-OSS as a special case
|
||||
if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result:
|
||||
result.remove("<|end|>")
|
||||
result.append("<|result|>")
|
||||
result = list(set(result))
|
||||
|
||||
if shared.args.verbose:
|
||||
logger.info("STOPPING_STRINGS=")
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)
|
||||
|
|
|
|||
Loading…
Reference in a new issue