diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 4d6018f9..b1979cbc 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -42,6 +42,7 @@ class GenerationOptions(BaseModel):
     auto_max_new_tokens: bool = False
     ban_eos_token: bool = False
     add_bos_token: bool = True
+    enable_thinking: bool = True
     skip_special_tokens: bool = True
     static_cache: bool = False
     truncation_length: int = 0
diff --git a/modules/chat.py b/modules/chat.py
index e117e6ee..98913d5c 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -90,6 +90,44 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru
     return prefix, suffix
 
 
+def get_thinking_suppression_string(template):
+    """
+    Determines what string needs to be added to suppress thinking mode
+    by comparing template renderings with thinking enabled vs disabled.
+    """
+
+    # Render with thinking enabled
+    with_thinking = template.render(
+        messages=[{'role': 'user', 'content': ''}],
+        builtin_tools=None,
+        tools=None,
+        tools_in_user_message=False,
+        add_generation_prompt=True,
+        enable_thinking=True
+    )
+
+    # Render with thinking disabled
+    without_thinking = template.render(
+        messages=[{'role': 'user', 'content': ''}],
+        builtin_tools=None,
+        tools=None,
+        tools_in_user_message=False,
+        add_generation_prompt=True,
+        enable_thinking=False
+    )
+
+    # Find the difference (what gets added to suppress thinking)
+    i = 0
+    while i < min(len(with_thinking), len(without_thinking)) and with_thinking[i] == without_thinking[i]:
+        i += 1
+
+    j = 0
+    while j < min(len(with_thinking), len(without_thinking)) - i and with_thinking[-1 - j] == without_thinking[-1 - j]:
+        j += 1
+
+    return without_thinking[i:len(without_thinking) - j if j else None]
+
+
 def generate_chat_prompt(user_input, state, **kwargs):
     impersonate = kwargs.get('impersonate', False)
     _continue = kwargs.get('_continue', False)
@@ -147,13 +185,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
     if user_input and not impersonate and not _continue:
         messages.append({"role": "user", "content": user_input})
 
-    def remove_extra_bos(prompt):
-        for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
-            while prompt.startswith(bos_token):
-                prompt = prompt[len(bos_token):]
-
-        return prompt
-
     def make_prompt(messages):
         if state['mode'] == 'chat-instruct' and _continue:
             prompt = renderer(messages=messages[:-1])
@@ -165,7 +196,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
             if state['custom_system_message'].strip() != '':
                 outer_messages.append({"role": "system", "content": state['custom_system_message']})
 
-            prompt = remove_extra_bos(prompt)
             command = state['chat-instruct_command']
             command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
             command = command.replace('<|prompt|>', prompt)
@@ -182,11 +212,10 @@
             outer_messages.append({"role": "user", "content": command})
             outer_messages.append({"role": "assistant", "content": prefix})
 
-            prompt = instruction_template.render(messages=outer_messages)
+            prompt = instruct_renderer(messages=outer_messages)
             suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
             if len(suffix) > 0:
                 prompt = prompt[:-len(suffix)]
-
         else:
             if _continue:
                 suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
@@ -199,7 +228,9 @@
 
                 prompt += prefix
 
-        prompt = remove_extra_bos(prompt)
+        if state['mode'] == 'instruct' and not any((_continue, impersonate, state['enable_thinking'])):
+            prompt += get_thinking_suppression_string(instruction_template)
+
         return prompt
 
     prompt = make_prompt(messages)
diff --git a/modules/loaders.py b/modules/loaders.py
index b8ae82d7..738198b1 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -143,6 +143,7 @@ def transformers_samplers():
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'skip_special_tokens',
         'static_cache',
         'seed',
@@ -195,6 +196,7 @@ loaders_samplers = {
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'skip_special_tokens',
         'seed',
         'sampler_priority',
@@ -241,6 +243,7 @@ loaders_samplers = {
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'skip_special_tokens',
         'seed',
         'sampler_priority',
@@ -279,6 +282,7 @@ loaders_samplers = {
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'skip_special_tokens',
         'seed',
         'custom_token_bans',
@@ -311,6 +315,7 @@ loaders_samplers = {
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'seed',
         'sampler_priority',
         'dry_sequence_breakers',
diff --git a/modules/shared.py b/modules/shared.py
index 5d9dd362..4c1179e3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -51,6 +51,7 @@ settings = {
     'auto_max_new_tokens': True,
     'ban_eos_token': False,
     'add_bos_token': True,
+    'enable_thinking': True,
     'skip_special_tokens': True,
     'stream': True,
     'static_cache': False,
diff --git a/modules/ui.py b/modules/ui.py
index f137e62d..4105f53b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -198,6 +198,7 @@ def list_interface_input_elements():
         'auto_max_new_tokens',
         'ban_eos_token',
         'add_bos_token',
+        'enable_thinking',
         'skip_special_tokens',
         'stream',
         'static_cache',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 6c2715af..3f609d71 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -82,6 +82,7 @@ def create_ui(default_preset):
                     shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
                     shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
                     shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+                    shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='enable_thinking', info='Used by Qwen3 to toggle <think> mode.')
                     shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
                     shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
                     shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.')
diff --git a/user_data/settings-template.yaml b/user_data/settings-template.yaml
index 83764f97..80923276 100644
--- a/user_data/settings-template.yaml
+++ b/user_data/settings-template.yaml
@@ -22,6 +22,7 @@ max_updates_second: 0
 auto_max_new_tokens: true
 ban_eos_token: false
 add_bos_token: true
+enable_thinking: true
 skip_special_tokens: true
 stream: true
 static_cache: false
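
For illustration only, not part of the patch above: a minimal, self-contained sketch of the prefix/suffix-diffing idea behind get_thinking_suppression_string(). The toy_template and the helper name thinking_suppression_string below are hypothetical stand-ins, not code from the repository, and the template is a simplified imitation of a Qwen3-style chat template rather than the real one shipped with the model.

# Hypothetical toy example, not repository code.
from jinja2 import Template

# Simplified stand-in for a Qwen3-style chat template: when enable_thinking is
# false, it pre-fills an empty <think> block after the assistant header.
toy_template = Template(
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n"
    "{% if not enable_thinking %}<think>\n\n</think>\n\n{% endif %}{% endif %}"
)


def thinking_suppression_string(template):
    # Render the same single-message conversation twice, once per setting.
    common = dict(messages=[{'role': 'user', 'content': ''}], add_generation_prompt=True)
    with_thinking = template.render(enable_thinking=True, **common)
    without_thinking = template.render(enable_thinking=False, **common)

    # Strip the shared prefix...
    i = 0
    while i < min(len(with_thinking), len(without_thinking)) and with_thinking[i] == without_thinking[i]:
        i += 1

    # ...and the shared suffix; what remains is the string the template adds
    # only when thinking is disabled.
    j = 0
    while j < min(len(with_thinking), len(without_thinking)) - i and with_thinking[-1 - j] == without_thinking[-1 - j]:
        j += 1

    return without_thinking[i:len(without_thinking) - j if j else None]


print(repr(thinking_suppression_string(toy_template)))  # '<think>\n\n</think>\n\n'

With Qwen3's real chat template the recovered string is likewise the empty <think></think> block emitted when enable_thinking is false, so appending it to the prompt pre-fills an empty reasoning section and the model skips its thinking phase.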