UI: Add a collapsible thinking block to messages with <think> steps (#6902)

This commit is contained in:
oobabooga 2025-04-25 18:02:02 -03:00 committed by GitHub
parent 0dd71e78c9
commit d35818f4e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 238 additions and 27 deletions

View file

@@ -417,16 +417,8 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
yield history
return
show_after = html.escape(state.get("show_after")) if state.get("show_after") else None
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
if show_after:
after = history["visible"][-1][1].partition(show_after)[2] or "*Is thinking...*"
yield {
'internal': history['internal'],
'visible': history['visible'][:-1] + [[history['visible'][-1][0], after]]
}
else:
yield history
yield history
def character_is_loaded(state, raise_exception=False):

View file

@@ -107,8 +107,87 @@ def replace_blockquote(m):
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
def extract_thinking_block(string):
    """Extract a thinking block from the beginning of an HTML-escaped string.

    The input has already been HTML-escaped, so the tags searched for are
    the escaped forms ``&lt;think&gt;`` / ``&lt;/think&gt;``.

    Returns a ``(thinking_content, remaining_content)`` tuple:
      - no thinking block at the start -> ``(None, string)`` unchanged;
      - both tags present -> content between them, plus everything after
        the closing tag;
      - only the opening tag (still streaming) -> everything after the
        opening tag, with an empty remainder.
    """
    if not string:
        return None, string

    THINK_START_TAG = "&lt;think&gt;"
    THINK_END_TAG = "&lt;/think&gt;"

    # The block must be at the very beginning of the message (leading
    # whitespace allowed).  A mere find() here would also match a tag in
    # the middle of the text and silently discard everything before it.
    if not string.lstrip().startswith(THINK_START_TAG):
        return None, string

    # Position of the opening tag in the original (unstripped) string
    start_pos = string.find(THINK_START_TAG)
    content_start = start_pos + len(THINK_START_TAG)

    # Look for the closing tag after the opening tag
    end_pos = string.find(THINK_END_TAG, content_start)
    if end_pos != -1:
        # Both tags found - extract content between them
        thinking_content = string[content_start:end_pos]
        remaining_content = string[end_pos + len(THINK_END_TAG):]
        return thinking_content, remaining_content
    else:
        # Only opening tag found - the reply is still inside the
        # thinking block, so there is no visible remainder yet
        thinking_content = string[content_start:]
        return thinking_content, ""
@functools.lru_cache(maxsize=None)
def convert_to_markdown(string):
def convert_to_markdown(string, message_id=None):
    """Convert a chat message to HTML, rendering any leading <think>
    block as a collapsible "thinking" section.

    string: the (already HTML-escaped) message text; may be empty.
    message_id: index of the message in the history, used to build a
        unique DOM id for the thinking block; falls back to "unknown".
    Returns the HTML for the message, with the thinking <details> block
    (if any) prepended to the converted main content.
    """
    if not string:
        return ""

    # Use a default message ID if none provided
    if message_id is None:
        message_id = "unknown"

    # Extract thinking block if present
    thinking_content, remaining_content = extract_thinking_block(string)

    # Process the main content
    html_output = process_markdown_content(remaining_content)

    # If thinking content was found, process it using the same function
    if thinking_content is not None:
        thinking_html = process_markdown_content(thinking_content)

        # Generate unique ID for the thinking block
        block_id = f"thinking-{message_id}-0"

        # Check if thinking is complete or still in progress: an empty
        # remainder means the closing tag has not arrived yet
        is_streaming = not remaining_content
        title_text = "Thinking..." if is_streaming else "Thought"

        # data-streaming lets the frontend style in-progress blocks;
        # the block starts open so the user sees the reasoning stream in
        thinking_block = f'''
        <details class="thinking-block" data-block-id="{block_id}" data-streaming="{str(is_streaming).lower()}" open>
            <summary class="thinking-header">
                <svg class="thinking-icon" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M8 1.33334C4.31868 1.33334 1.33334 4.31868 1.33334 8.00001C1.33334 11.6813 4.31868 14.6667 8 14.6667C11.6813 14.6667 14.6667 11.6813 14.6667 8.00001C14.6667 4.31868 11.6813 1.33334 8 1.33334Z" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                    <path d="M8 10.6667V8.00001" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                    <path d="M8 5.33334H8.00667" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                </svg>
                <span class="thinking-title">{title_text}</span>
            </summary>
            <div class="thinking-content pretty_scrollbar">{thinking_html}</div>
        </details>
        '''

        # Prepend the thinking block to the message HTML
        html_output = thinking_block + html_output

    return html_output
def process_markdown_content(string):
"""Process a string through the markdown conversion pipeline."""
if not string:
return ""
@@ -209,15 +288,15 @@ def convert_to_markdown(string):
return html_output
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
    '''
    Used to avoid caching convert_to_markdown calls during streaming.

    When use_cache is False (the last, still-changing message), the
    lru_cache is bypassed via __wrapped__ so partial text is never cached.
    '''
    if use_cache:
        return convert_to_markdown(string, message_id=message_id)

    return convert_to_markdown.__wrapped__(string, message_id=message_id)
def generate_basic_html(string):
@@ -273,7 +352,7 @@ def generate_instruct_html(history):
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (
@@ -320,7 +399,7 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (
@@ -360,7 +439,7 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (

View file

@@ -59,7 +59,6 @@ settings = {
'seed': -1,
'custom_stopping_strings': '',
'custom_token_bans': '',
'show_after': '',
'negative_prompt': '',
'autoload_model': False,
'dark_theme': True,

View file

@@ -207,7 +207,6 @@ def list_interface_input_elements():
'sampler_priority',
'custom_stopping_strings',
'custom_token_bans',
'show_after',
'negative_prompt',
'dry_sequence_breakers',
'grammar_string',

View file

@@ -93,7 +93,6 @@ def create_ui(default_preset):
shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.')
shared.gradio['show_after'] = gr.Textbox(value=shared.settings['show_after'] or None, label='Show after', info='Hide the reply before this text.', placeholder="</think>")
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar'])
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
with gr.Row() as shared.gradio['grammar_file_row']: