UI: Add an enable_thinking option to enable/disable Qwen3 thinking

oobabooga 2025-04-28 22:37:01 -07:00
parent 1ee0acc852
commit d10bded7f8
7 changed files with 52 additions and 11 deletions


@@ -42,6 +42,7 @@ class GenerationOptions(BaseModel):
auto_max_new_tokens: bool = False
ban_eos_token: bool = False
add_bos_token: bool = True
enable_thinking: bool = True
skip_special_tokens: bool = True
static_cache: bool = False
truncation_length: int = 0
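
Because the new field lives in GenerationOptions, it should also be settable per request through the project's OpenAI-compatible API. A minimal sketch (not part of this commit); the address/port, the "mode" parameter, and the acceptance of extra generation options alongside the standard OpenAI payload are assumptions about the default configuration:

import requests

# Hypothetical request: the URL and the extra non-OpenAI fields ("mode",
# "enable_thinking") are assumptions about the webui's default API setup.
response = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "What is 7 * 6?"}],
        "mode": "instruct",
        "enable_thinking": False,  # new GenerationOptions field added above
    },
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])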


@@ -90,6 +90,44 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru
    return prefix, suffix


def get_thinking_suppression_string(template):
    """
    Determines what string needs to be added to suppress thinking mode
    by comparing template renderings with thinking enabled vs disabled.
    """

    # Render with thinking enabled
    with_thinking = template.render(
        messages=[{'role': 'user', 'content': ''}],
        builtin_tools=None,
        tools=None,
        tools_in_user_message=False,
        add_generation_prompt=True,
        enable_thinking=True
    )

    # Render with thinking disabled
    without_thinking = template.render(
        messages=[{'role': 'user', 'content': ''}],
        builtin_tools=None,
        tools=None,
        tools_in_user_message=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    # Find the difference (what gets added to suppress thinking)
    i = 0
    while i < min(len(with_thinking), len(without_thinking)) and with_thinking[i] == without_thinking[i]:
        i += 1

    j = 0
    while j < min(len(with_thinking), len(without_thinking)) - i and with_thinking[-1 - j] == without_thinking[-1 - j]:
        j += 1

    return without_thinking[i:len(without_thinking) - j if j else None]


def generate_chat_prompt(user_input, state, **kwargs):
    impersonate = kwargs.get('impersonate', False)
    _continue = kwargs.get('_continue', False)
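
To make the prefix/suffix comparison in get_thinking_suppression_string concrete, here is a standalone sketch (not part of this commit) that applies the same logic to two hypothetical Qwen3-style renderings; the exact template output is an assumption, but it shows why the returned string is the block the template appends when thinking is disabled:

# Hypothetical renderings of a Qwen3-style chat template (assumed output)
with_thinking = "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"
without_thinking = "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Common-prefix scan, as in the function above
i = 0
while i < min(len(with_thinking), len(without_thinking)) and with_thinking[i] == without_thinking[i]:
    i += 1

# Common-suffix scan, bounded so it never overlaps the prefix
j = 0
while j < min(len(with_thinking), len(without_thinking)) - i and with_thinking[-1 - j] == without_thinking[-1 - j]:
    j += 1

# The leftover middle of the "disabled" rendering is the suppression string
print(repr(without_thinking[i:len(without_thinking) - j if j else None]))
# Prints '<think>\n\n</think>\n\n' for these example strings: an empty
# reasoning block appended to the prompt so the model skips thinking.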
@@ -147,13 +185,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
if user_input and not impersonate and not _continue:
messages.append({"role": "user", "content": user_input})
def remove_extra_bos(prompt):
for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
while prompt.startswith(bos_token):
prompt = prompt[len(bos_token):]
return prompt
def make_prompt(messages):
if state['mode'] == 'chat-instruct' and _continue:
prompt = renderer(messages=messages[:-1])
@@ -165,7 +196,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
if state['custom_system_message'].strip() != '':
outer_messages.append({"role": "system", "content": state['custom_system_message']})
prompt = remove_extra_bos(prompt)
command = state['chat-instruct_command']
command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
command = command.replace('<|prompt|>', prompt)
@@ -182,11 +212,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
outer_messages.append({"role": "user", "content": command})
outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruction_template.render(messages=outer_messages)
prompt = instruct_renderer(messages=outer_messages)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
@@ -199,7 +228,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
prompt += prefix
prompt = remove_extra_bos(prompt)
if state['mode'] == 'instruct' and not any((_continue, impersonate, state['enable_thinking'])):
prompt += get_thinking_suppression_string(instruction_template)
return prompt
prompt = make_prompt(messages)


@@ -143,6 +143,7 @@ def transformers_samplers():
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'static_cache',
'seed',
@@ -195,6 +196,7 @@ loaders_samplers = {
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'seed',
'sampler_priority',
@@ -241,6 +243,7 @@ loaders_samplers = {
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'seed',
'sampler_priority',
@@ -279,6 +282,7 @@ loaders_samplers = {
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'seed',
'custom_token_bans',
@@ -311,6 +315,7 @@ loaders_samplers = {
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'seed',
'sampler_priority',
'dry_sequence_breakers',


@@ -51,6 +51,7 @@ settings = {
'auto_max_new_tokens': True,
'ban_eos_token': False,
'add_bos_token': True,
'enable_thinking': True,
'skip_special_tokens': True,
'stream': True,
'static_cache': False,


@@ -198,6 +198,7 @@ def list_interface_input_elements():
'auto_max_new_tokens',
'ban_eos_token',
'add_bos_token',
'enable_thinking',
'skip_special_tokens',
'stream',
'static_cache',


@@ -82,6 +82,7 @@ def create_ui(default_preset):
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='enable_thinking', info='Used by Qwen3 to toggle <think> mode.')
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.')


@@ -22,6 +22,7 @@ max_updates_second: 0
auto_max_new_tokens: true
ban_eos_token: false
add_bos_token: true
enable_thinking: true
skip_special_tokens: true
stream: true
static_cache: false