chat.py code simplifications

This commit is contained in:
oobabooga 2025-08-25 17:20:51 -07:00
parent d08800c359
commit f919cdf881

View file

@ -86,36 +86,6 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
'''
Given a Jinja template, reverse-engineers the prefix and the suffix for
an assistant message (if impersonate=False) or an user message
(if impersonate=True)
'''
if impersonate:
messages = [
{"role": "user", "content": "<<|user-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"},
]
else:
messages = [
{"role": "assistant", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|user-message-2|>>"},
]
prompt = renderer(messages=messages)
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
suffix = prompt.split("<<|user-message-2|>>")[1]
prefix = suffix_plus_prefix[len(suffix):]
if strip_trailing_spaces:
prefix = prefix.rstrip(' ')
return prefix, suffix
def get_thinking_suppression_string(template):
"""
Determines what string needs to be added to suppress thinking mode
@ -341,26 +311,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
command = command.replace('<|prompt|>', prompt)
command = replace_character_names(command, state['name1'], state['name2'])
if _continue:
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
outer_messages = []
if state['custom_system_message'].strip() != '':
outer_messages.append({"role": "system", "content": state['custom_system_message']})
outer_messages.append({"role": "user", "content": command})
outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruct_renderer(messages=outer_messages)
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
prompt = instruct_renderer(
messages=outer_messages,
add_generation_prompt=True
)
else:
# Handle GPT-OSS as a special case when continuing
# (otherwise the thinking block gets removed...)
@ -375,29 +335,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
assistant_reply_so_far += f"<|channel|>final<|message|>{last_message_to_continue.get('content', '')}"
prompt += assistant_reply_so_far
else:
prompt = renderer(messages=messages)
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
# Handle GPT-OSS as a special case when not continuing
if '<|channel|>final<|message|>' in state['instruction_template_str']:
if prefix.endswith("<|channel|>final<|message|>"):
prefix = prefix[:-len("<|channel|>final<|message|>")]
if impersonate:
prefix += "<|message|>"
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
prompt += prefix
if state['mode'] == 'instruct' and 'enable_thinking' in state['instruction_template_str'] and not any((_continue, impersonate, state['enable_thinking'])):
prompt += get_thinking_suppression_string(instruction_template)
prompt = renderer(
messages=messages,
add_generation_prompt=True
)
return prompt
@ -523,24 +464,41 @@ def get_stopping_strings(state):
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append(renderer)
for renderer in renderers:
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
fake_messages = [
{"role": "user", "content": "first user message"},
{"role": "assistant", "content": "first assistant message"},
{"role": "user", "content": "second user message"},
{"role": "assistant", "content": "second assistant message"},
]
stopping_strings += [
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
stopping_strings = []
for renderer in renderers:
prompt = renderer(messages=fake_messages)
# Find positions of each message content
first_user_end = prompt.find("first user message") + len("first user message")
first_assistant_start = prompt.find("first assistant message")
first_assistant_end = prompt.find("first assistant message") + len("first assistant message")
second_user_start = prompt.find("second user message")
second_assistant_end = prompt.find("second assistant message") + len("second assistant message")
# Extract pieces of text potentially containing unique stopping strings
texts = [
prompt[first_user_end:first_assistant_start],
prompt[first_assistant_end:second_user_start],
prompt[second_assistant_end:]
]
# Try to find the EOT token
for item in stopping_strings.copy():
item = item.strip()
if item.startswith("<") and ">" in item:
stopping_strings.append(item.split(">")[0] + ">")
elif item.startswith("[") and "]" in item:
stopping_strings.append(item.split("]")[0] + "]")
for text in texts:
text = text.strip()
if text.startswith("<") and ">" in text:
stopping_strings.append(text.split(">")[0] + ">")
elif text.startswith("[") and "]" in text:
stopping_strings.append(text.split("]")[0] + "]")
elif text.startswith("(") and ")" in text:
stopping_strings.append(text.split(")")[0] + ")")
elif text.startswith("{") and "}" in text:
stopping_strings.append(text.split("}")[0] + "}")
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings')
@ -549,12 +507,6 @@ def get_stopping_strings(state):
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
result = list(set(result))
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result:
result.remove("<|end|>")
result.append("<|result|>")
result = list(set(result))
if shared.args.verbose:
logger.info("STOPPING_STRINGS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)