Handle GPT-OSS as a special case when continuing

This commit is contained in:
oobabooga 2025-08-06 08:05:37 -07:00
parent 6ce4b353c4
commit 0c1403f2c7

View file

@@ -237,7 +237,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Insert as structured message
msg_dict = {"role": "assistant", "content": final_content}
if thinking_content:
if '<|channel|>analysis<|message|>' in assistant_msg:
msg_dict["thinking"] = thinking_content
messages.insert(insert_pos, msg_dict)
@@ -328,25 +328,42 @@ def generate_chat_prompt(user_input, state, **kwargs):
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
# Handle GPT-OSS as a special case when continuing
if _continue and '<|channel|>final<|message|>' in state['instruction_template_str']:
# This prevents the template from stripping the analysis block of the message being continued.
last_message_to_continue = messages[-1]
prompt = renderer(messages=messages[:-1])
assistant_reply_so_far = ""
if 'thinking' in last_message_to_continue:
assistant_reply_so_far += f"<|start|>assistant<|channel|>analysis<|message|>{last_message_to_continue['thinking']}<|end|>"
assistant_reply_so_far += f"<|start|>assistant<|channel|>final<|message|>{last_message_to_continue.get('content', '')}"
prompt += assistant_reply_so_far
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
prompt = renderer(messages=messages)
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str']:
if prefix.endswith("<|channel|>final<|message|>"):
prefix = prefix[:-len("<|channel|>final<|message|>")]
# Handle GPT-OSS as a special case when not continuing
if '<|channel|>final<|message|>' in state['instruction_template_str']:
if prefix.endswith("<|channel|>final<|message|>"):
prefix = prefix[:-len("<|channel|>final<|message|>")]
if impersonate:
prefix += "<|message|>"
if impersonate:
prefix += "<|message|>"
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
prompt += prefix
prompt += prefix
if state['mode'] == 'instruct' and 'enable_thinking' in state['instruction_template_str'] and not any((_continue, impersonate, state['enable_thinking'])):
prompt += get_thinking_suppression_string(instruction_template)