Fix continue for GPT-OSS (hopefully the final fix)

This commit is contained in:
oobabooga 2025-08-06 10:18:42 -07:00
parent 0c1403f2c7
commit 3e24f455c8

View file

@@ -219,21 +219,39 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Extract analysis content if present
if '<|channel|>analysis<|message|>' in assistant_msg:
analysis_start = assistant_msg.find('<|channel|>analysis<|message|>') + len('<|channel|>analysis<|message|>')
if '<|start|>assistant<|channel|>final<|message|>' in assistant_msg:
analysis_end = assistant_msg.find('<|start|>assistant<|channel|>final<|message|>')
else:
analysis_end = len(assistant_msg)
# Split the message by the analysis tag to isolate the content that follows
parts = assistant_msg.split('<|channel|>analysis<|message|>', 1)
if len(parts) > 1:
# The content is everything after the tag
potential_content = parts[1]
thinking_content = assistant_msg[analysis_start:analysis_end].strip()
# Now, find the end of this content block
analysis_end_tag = '<|end|>'
if analysis_end_tag in potential_content:
thinking_content = potential_content.split(analysis_end_tag, 1)[0].strip()
else:
# Fallback: if no <|end|> tag, stop at the start of the final channel if it exists
final_channel_tag = '<|channel|>final<|message|>'
if final_channel_tag in potential_content:
thinking_content = potential_content.split(final_channel_tag, 1)[0].strip()
else:
thinking_content = potential_content.strip()
# Extract final content if present
if '<|start|>assistant<|channel|>final<|message|>' in assistant_msg:
final_start = assistant_msg.find('<|start|>assistant<|channel|>final<|message|>') + len('<|start|>assistant<|channel|>final<|message|>')
final_content = assistant_msg[final_start:].strip()
elif '<|channel|>final<|message|>' in assistant_msg:
final_start = assistant_msg.find('<|channel|>final<|message|>') + len('<|channel|>final<|message|>')
final_content = assistant_msg[final_start:].strip()
final_tag_to_find = '<|channel|>final<|message|>'
if final_tag_to_find in assistant_msg:
# Split the message by the final tag to isolate the content that follows
parts = assistant_msg.split(final_tag_to_find, 1)
if len(parts) > 1:
# The content is everything after the tag
potential_content = parts[1]
# Now, find the end of this content block
final_end_tag = '<|end|>'
if final_end_tag in potential_content:
final_content = potential_content.split(final_end_tag, 1)[0].strip()
else:
final_content = potential_content.strip()
# Insert as structured message
msg_dict = {"role": "assistant", "content": final_content}
@@ -330,16 +348,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
else:
# Handle GPT-OSS as a special case when continuing
if _continue and '<|channel|>final<|message|>' in state['instruction_template_str']:
# This prevents the template from stripping the analysis block of the message being continued.
last_message_to_continue = messages[-1]
prompt = renderer(messages=messages[:-1])
assistant_reply_so_far = ""
if 'thinking' in last_message_to_continue:
assistant_reply_so_far += f"<|start|>assistant<|channel|>analysis<|message|>{last_message_to_continue['thinking']}<|end|>"
# Start the assistant turn wrapper
assistant_reply_so_far = "<|start|>assistant"
assistant_reply_so_far += f"<|start|>assistant<|channel|>final<|message|>{last_message_to_continue.get('content', '')}"
if 'thinking' in last_message_to_continue:
assistant_reply_so_far += f"<|channel|>analysis<|message|>{last_message_to_continue['thinking']}<|end|>"
assistant_reply_so_far += f"<|channel|>final<|message|>{last_message_to_continue.get('content', '')}"
prompt += assistant_reply_so_far