Fix the GPT-OSS template

This commit is contained in:
oobabooga 2025-08-06 06:42:45 -07:00
parent 7c82d65a9d
commit 6ce4b353c4

View file

@@ -211,7 +211,39 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.insert(insert_pos, {"role": "tool", "content": tool_msg})
if assistant_msg:
messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
# Handle GPT-OSS as a special case
if '<|channel|>analysis<|message|>' in assistant_msg or '<|channel|>final<|message|>' in assistant_msg:
thinking_content = ""
final_content = ""
# Extract analysis content if present
if '<|channel|>analysis<|message|>' in assistant_msg:
analysis_start = assistant_msg.find('<|channel|>analysis<|message|>') + len('<|channel|>analysis<|message|>')
if '<|start|>assistant<|channel|>final<|message|>' in assistant_msg:
analysis_end = assistant_msg.find('<|start|>assistant<|channel|>final<|message|>')
else:
analysis_end = len(assistant_msg)
thinking_content = assistant_msg[analysis_start:analysis_end].strip()
# Extract final content if present
if '<|start|>assistant<|channel|>final<|message|>' in assistant_msg:
final_start = assistant_msg.find('<|start|>assistant<|channel|>final<|message|>') + len('<|start|>assistant<|channel|>final<|message|>')
final_content = assistant_msg[final_start:].strip()
elif '<|channel|>final<|message|>' in assistant_msg:
final_start = assistant_msg.find('<|channel|>final<|message|>') + len('<|channel|>final<|message|>')
final_content = assistant_msg[final_start:].strip()
# Insert as structured message
msg_dict = {"role": "assistant", "content": final_content}
if thinking_content:
msg_dict["thinking"] = thinking_content
messages.insert(insert_pos, msg_dict)
else:
messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
# Check for user message attachments in metadata
@@ -305,7 +337,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str']:
prefix = prefix.rstrip("<|channel|>final<|message|>")
if prefix.endswith("<|channel|>final<|message|>"):
prefix = prefix[:-len("<|channel|>final<|message|>")]
if impersonate:
prefix += "<|message|>"