mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Handle GPT-OSS as a special template case
This commit is contained in:
parent
fbea21a1f1
commit
7c82d65a9d
|
|
@ -302,6 +302,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
prompt = prompt[:-len(suffix)]
|
||||
else:
|
||||
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
||||
|
||||
# Handle GPT-OSS as a special case
|
||||
if '<|channel|>final<|message|>' in state['instruction_template_str']:
|
||||
prefix = prefix.rstrip("<|channel|>final<|message|>")
|
||||
if impersonate:
|
||||
prefix += "<|message|>"
|
||||
|
||||
if state['mode'] == 'chat' and not impersonate:
|
||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||
|
||||
|
|
@ -460,6 +467,12 @@ def get_stopping_strings(state):
|
|||
result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)]
|
||||
result = list(set(result))
|
||||
|
||||
# Handle GPT-OSS as a special case
|
||||
if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result:
|
||||
result.remove("<|end|>")
|
||||
result.append("<|result|>")
|
||||
result = list(set(result))
|
||||
|
||||
if shared.args.verbose:
|
||||
logger.info("STOPPING_STRINGS=")
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result)
|
||||
|
|
@ -650,6 +663,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
output = apply_extensions('history', output)
|
||||
state = apply_extensions('state', state)
|
||||
|
||||
# Handle GPT-OSS as a special case
|
||||
if '<|channel|>final<|message|>' in state['instruction_template_str']:
|
||||
state['skip_special_tokens'] = False
|
||||
|
||||
# Let the jinja2 template handle the BOS token
|
||||
if state['mode'] in ['instruct', 'chat-instruct']:
|
||||
state['add_bos_token'] = False
|
||||
|
|
|
|||
Loading…
Reference in a new issue