Handle GPT-OSS as a special template case

This commit is contained in:
oobabooga 2025-08-05 18:05:09 -07:00
parent fbea21a1f1
commit 7c82d65a9d

View file

@@ -302,6 +302,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
prompt = prompt[:-len(suffix)]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str']:
prefix = prefix.rstrip("<|channel|>final<|message|>")
if impersonate:
prefix += "<|message|>"
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
@@ -460,6 +467,12 @@ def get_stopping_strings(state):
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
result = list(set(result))
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result:
result.remove("<|end|>")
result.append("<|result|>")
result = list(set(result))
if shared.args.verbose:
logger.info("STOPPING_STRINGS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)
@@ -650,6 +663,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
output = apply_extensions('history', output)
state = apply_extensions('state', state)
# Handle GPT-OSS as a special case
if '<|channel|>final<|message|>' in state['instruction_template_str']:
state['skip_special_tokens'] = False
# Let the jinja2 template handle the BOS token
if state['mode'] in ['instruct', 'chat-instruct']:
state['add_bos_token'] = False