Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-12-06 07:12:10 +01:00)

UI: Add an enable_thinking option to enable/disable Qwen3 thinking
This commit is contained in:
parent 1ee0acc852
commit d10bded7f8

@@ -42,6 +42,7 @@ class GenerationOptions(BaseModel):
    auto_max_new_tokens: bool = False
    ban_eos_token: bool = False
    add_bos_token: bool = True
    enable_thinking: bool = True
    skip_special_tokens: bool = True
    static_cache: bool = False
    truncation_length: int = 0

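Since GenerationOptions is the schema shared by the OpenAI-compatible API, the new flag should also be accepted directly in request bodies. A minimal client-side sketch, assuming the api extension is running on its default local port (the URL, port, and payload below are illustrative, not part of this commit):

import requests

# Hypothetical request against a local text-generation-webui instance with the
# OpenAI-compatible API enabled; "enable_thinking" rides along with the other
# GenerationOptions fields.
response = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "What is 17 * 23?"}],
        "mode": "instruct",
        "enable_thinking": False,  # the field added in this commit
        "max_tokens": 256,
    },
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])
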
@@ -90,6 +90,44 @@ def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=Tru
    return prefix, suffix


def get_thinking_suppression_string(template):
    """
    Determines what string needs to be added to suppress thinking mode
    by comparing template renderings with thinking enabled vs disabled.
    """

    # Render with thinking enabled
    with_thinking = template.render(
        messages=[{'role': 'user', 'content': ''}],
        builtin_tools=None,
        tools=None,
        tools_in_user_message=False,
        add_generation_prompt=True,
        enable_thinking=True
    )

    # Render with thinking disabled
    without_thinking = template.render(
        messages=[{'role': 'user', 'content': ''}],
        builtin_tools=None,
        tools=None,
        tools_in_user_message=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    # Find the difference (what gets added to suppress thinking)
    i = 0
    while i < min(len(with_thinking), len(without_thinking)) and with_thinking[i] == without_thinking[i]:
        i += 1

    j = 0
    while j < min(len(with_thinking), len(without_thinking)) - i and with_thinking[-1 - j] == without_thinking[-1 - j]:
        j += 1

    return without_thinking[i:len(without_thinking) - j if j else None]


def generate_chat_prompt(user_input, state, **kwargs):
    impersonate = kwargs.get('impersonate', False)
    _continue = kwargs.get('_continue', False)

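To illustrate what this helper computes, here is a toy example. The template string is a stand-in written for this sketch, not the real Qwen3 template; it only mimics the relevant behaviour of appending an empty <think> block to the generation prompt when enable_thinking is false.

from jinja2 import Template

# Stand-in, Qwen3-style template: only the enable_thinking branch matters here.
toy_template = Template(
    "{% for m in messages %}<|im_start|>{{ m['role'] }}\n{{ m['content'] }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n"
    "{% if not enable_thinking %}<think>\n\n</think>\n\n{% endif %}"
    "{% endif %}"
)

# The two renderings differ only by the injected block, so the common-prefix /
# common-suffix comparison isolates exactly that string.
print(repr(get_thinking_suppression_string(toy_template)))
# -> '<think>\n\n</think>\n\n'
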
@@ -147,13 +185,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
    if user_input and not impersonate and not _continue:
        messages.append({"role": "user", "content": user_input})

    def remove_extra_bos(prompt):
        for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
            while prompt.startswith(bos_token):
                prompt = prompt[len(bos_token):]

        return prompt

    def make_prompt(messages):
        if state['mode'] == 'chat-instruct' and _continue:
            prompt = renderer(messages=messages[:-1])

@@ -165,7 +196,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
            if state['custom_system_message'].strip() != '':
                outer_messages.append({"role": "system", "content": state['custom_system_message']})

            prompt = remove_extra_bos(prompt)
            command = state['chat-instruct_command']
            command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
            command = command.replace('<|prompt|>', prompt)

@@ -182,11 +212,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
            outer_messages.append({"role": "user", "content": command})
            outer_messages.append({"role": "assistant", "content": prefix})

            prompt = instruction_template.render(messages=outer_messages)
            prompt = instruct_renderer(messages=outer_messages)
            suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
            if len(suffix) > 0:
                prompt = prompt[:-len(suffix)]

        else:
            if _continue:
                suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]

@@ -199,7 +228,9 @@ def generate_chat_prompt(user_input, state, **kwargs):

                prompt += prefix

        prompt = remove_extra_bos(prompt)
        if state['mode'] == 'instruct' and not any((_continue, impersonate, state['enable_thinking'])):
            prompt += get_thinking_suppression_string(instruction_template)

        return prompt

    prompt = make_prompt(messages)

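The new condition only fires for a fresh instruct-mode reply: when continuing, impersonating, or when the user leaves enable_thinking checked, the prompt is left untouched. A small sketch of the effect, assuming a Qwen3 template whose suppression string is an empty <think> block (the literal prompt text below is illustrative, not taken from this diff):

_continue, impersonate = False, False
state = {'mode': 'instruct', 'enable_thinking': False}

prompt = "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
if state['mode'] == 'instruct' and not any((_continue, impersonate, state['enable_thinking'])):
    # For Qwen3 this is what get_thinking_suppression_string() is assumed to return.
    prompt += "<think>\n\n</think>\n\n"

# The model now sees an already-closed, empty think block and answers directly.
print(prompt)
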
@@ -143,6 +143,7 @@ def transformers_samplers():
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'skip_special_tokens',
        'static_cache',
        'seed',

@@ -195,6 +196,7 @@ loaders_samplers = {
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'skip_special_tokens',
        'seed',
        'sampler_priority',

@@ -241,6 +243,7 @@ loaders_samplers = {
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'skip_special_tokens',
        'seed',
        'sampler_priority',

@@ -279,6 +282,7 @@ loaders_samplers = {
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'skip_special_tokens',
        'seed',
        'custom_token_bans',

@@ -311,6 +315,7 @@ loaders_samplers = {
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'seed',
        'sampler_priority',
        'dry_sequence_breakers',

@@ -51,6 +51,7 @@ settings = {
    'auto_max_new_tokens': True,
    'ban_eos_token': False,
    'add_bos_token': True,
    'enable_thinking': True,
    'skip_special_tokens': True,
    'stream': True,
    'static_cache': False,

@@ -198,6 +198,7 @@ def list_interface_input_elements():
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'skip_special_tokens',
        'stream',
        'static_cache',

@@ -82,6 +82,7 @@ def create_ui(default_preset):
    shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
    shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
    shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
    shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='enable_thinking', info='Used by Qwen3 to toggle <think> mode.')
    shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
    shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
    shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.')

@@ -22,6 +22,7 @@ max_updates_second: 0
auto_max_new_tokens: true
ban_eos_token: false
add_bos_token: true
enable_thinking: true
skip_special_tokens: true
stream: true
static_cache: false