diff --git a/modules/chat.py b/modules/chat.py index e7fd86f9..66f89c70 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -302,6 +302,13 @@ def generate_chat_prompt(user_input, state, **kwargs): prompt = prompt[:-len(suffix)] else: prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] + + # Handle GPT-OSS as a special case + if '<|channel|>final<|message|>' in state['instruction_template_str']: + prefix = prefix.rstrip("<|channel|>final<|message|>") + if impersonate: + prefix += "<|message|>" + if state['mode'] == 'chat' and not impersonate: prefix = apply_extensions('bot_prefix', prefix, state) @@ -460,6 +467,12 @@ def get_stopping_strings(state): result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)] result = list(set(result)) + # Handle GPT-OSS as a special case + if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result: + result.remove("<|end|>") + result.append("<|result|>") + result = list(set(result)) + if shared.args.verbose: logger.info("STOPPING_STRINGS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result) @@ -650,6 +663,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output = apply_extensions('history', output) state = apply_extensions('state', state) + # Handle GPT-OSS as a special case + if '<|channel|>final<|message|>' in state['instruction_template_str']: + state['skip_special_tokens'] = False + # Let the jinja2 template handle the BOS token if state['mode'] in ['instruct', 'chat-instruct']: state['add_bos_token'] = False