Mirror of https://github.com/oobabooga/text-generation-webui.git — last synced 2026-04-04 14:17:28 +00:00.
UI: Add a collapsible thinking block to messages with <think> steps (#6902)
This commit is contained in:
parent
0dd71e78c9
commit
d35818f4e1
8 changed files with 238 additions and 27 deletions
|
|
@ -417,16 +417,8 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
|
|||
yield history
|
||||
return
|
||||
|
||||
show_after = html.escape(state.get("show_after")) if state.get("show_after") else None
|
||||
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
|
||||
if show_after:
|
||||
after = history["visible"][-1][1].partition(show_after)[2] or "*Is thinking...*"
|
||||
yield {
|
||||
'internal': history['internal'],
|
||||
'visible': history['visible'][:-1] + [[history['visible'][-1][0], after]]
|
||||
}
|
||||
else:
|
||||
yield history
|
||||
yield history
|
||||
|
||||
|
||||
def character_is_loaded(state, raise_exception=False):
|
||||
|
|
|
|||
|
|
@ -107,8 +107,87 @@ def replace_blockquote(m):
|
|||
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
|
||||
|
||||
|
||||
def extract_thinking_block(string):
    """Extract a <think>...</think> block from the beginning of a string.

    Parameters:
        string: The raw message text, possibly starting with a thinking block.

    Returns:
        A (thinking_content, remaining_content) tuple. When the string does
        not begin with a <think> tag (leading whitespace allowed),
        thinking_content is None and the string is returned unchanged.
        When only the opening tag is present (reply still streaming), the
        remainder after the tag is the thinking content and the visible
        content is "".
    """
    if not string:
        return None, string

    THINK_START_TAG = "<think>"
    THINK_END_TAG = "</think>"

    # The tag must open the message (after optional leading whitespace).
    # A <think> appearing mid-string is ordinary content; the previous
    # find()-based check silently discarded any text before the tag.
    if not string.lstrip().startswith(THINK_START_TAG):
        return None, string

    # Locate the tag in the original (non-stripped) string so indices line up
    start_pos = string.find(THINK_START_TAG)
    content_start = start_pos + len(THINK_START_TAG)

    # Look for the closing tag after the opening one
    end_pos = string.find(THINK_END_TAG, content_start)

    if end_pos != -1:
        # Both tags found - split into thinking content and visible remainder
        thinking_content = string[content_start:end_pos]
        remaining_content = string[end_pos + len(THINK_END_TAG):]
        return thinking_content, remaining_content
    else:
        # Only the opening tag so far - everything else is thinking content
        return string[content_start:], ""
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def convert_to_markdown(string, message_id=None):
    """Convert a message string to HTML, rendering a leading <think> block
    as a collapsible <details> "thinking" section.

    Parameters:
        string: Raw message text, possibly starting with a <think> block.
        message_id: Identifier used to build a stable DOM id for the
            thinking block; defaults to "unknown" when not provided.

    Returns:
        The HTML for the message, with the thinking block (if any)
        prepended to the rendered main content.
    """
    if not string:
        return ""

    # Use a default message ID if none provided
    if message_id is None:
        message_id = "unknown"

    # Extract thinking block if present
    thinking_content, remaining_content = extract_thinking_block(string)

    # Process the main content
    html_output = process_markdown_content(remaining_content)

    # If thinking content was found, process it using the same function
    if thinking_content is not None:
        thinking_html = process_markdown_content(thinking_content)

        # Unique ID so the frontend can track this block's open/closed state
        block_id = f"thinking-{message_id}-0"

        # No visible content after the block yet means the model is still
        # streaming its thinking
        is_streaming = not remaining_content
        title_text = "Thinking..." if is_streaming else "Thought"

        thinking_block = f'''
        <details class="thinking-block" data-block-id="{block_id}" data-streaming="{str(is_streaming).lower()}" open>
            <summary class="thinking-header">
                <svg class="thinking-icon" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M8 1.33334C4.31868 1.33334 1.33334 4.31868 1.33334 8.00001C1.33334 11.6813 4.31868 14.6667 8 14.6667C11.6813 14.6667 14.6667 11.6813 14.6667 8.00001C14.6667 4.31868 11.6813 1.33334 8 1.33334Z" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                    <path d="M8 10.6667V8.00001" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                    <path d="M8 5.33334H8.00667" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
                </svg>
                <span class="thinking-title">{title_text}</span>
            </summary>
            <div class="thinking-content pretty_scrollbar">{thinking_html}</div>
        </details>
        '''

        # Prepend the thinking block to the message HTML
        html_output = thinking_block + html_output

    return html_output
|
||||
|
||||
|
||||
def process_markdown_content(string):
|
||||
"""Process a string through the markdown conversion pipeline."""
|
||||
if not string:
|
||||
return ""
|
||||
|
||||
|
|
@ -209,15 +288,15 @@ def convert_to_markdown(string):
|
|||
return html_output
|
||||
|
||||
|
||||
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
    '''
    Used to avoid caching convert_to_markdown calls during streaming.

    When use_cache is False, calls the undecorated function through
    __wrapped__ so the partial (still-streaming) text does not pollute
    the lru_cache with entries that will never be requested again.
    '''
    if use_cache:
        return convert_to_markdown(string, message_id=message_id)

    return convert_to_markdown.__wrapped__(string, message_id=message_id)
|
||||
|
||||
|
||||
def generate_basic_html(string):
|
||||
|
|
@ -273,7 +352,7 @@ def generate_instruct_html(history):
|
|||
for i in range(len(history['visible'])):
|
||||
row_visible = history['visible'][i]
|
||||
row_internal = history['internal'][i]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
|
||||
if converted_visible[0]: # Don't display empty user messages
|
||||
output += (
|
||||
|
|
@ -320,7 +399,7 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
|
|||
for i in range(len(history['visible'])):
|
||||
row_visible = history['visible'][i]
|
||||
row_internal = history['internal'][i]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
|
||||
if converted_visible[0]: # Don't display empty user messages
|
||||
output += (
|
||||
|
|
@ -360,7 +439,7 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
|
|||
for i in range(len(history['visible'])):
|
||||
row_visible = history['visible'][i]
|
||||
row_internal = history['internal'][i]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
|
||||
|
||||
if converted_visible[0]: # Don't display empty user messages
|
||||
output += (
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ settings = {
|
|||
'seed': -1,
|
||||
'custom_stopping_strings': '',
|
||||
'custom_token_bans': '',
|
||||
'show_after': '',
|
||||
'negative_prompt': '',
|
||||
'autoload_model': False,
|
||||
'dark_theme': True,
|
||||
|
|
|
|||
|
|
@ -207,7 +207,6 @@ def list_interface_input_elements():
|
|||
'sampler_priority',
|
||||
'custom_stopping_strings',
|
||||
'custom_token_bans',
|
||||
'show_after',
|
||||
'negative_prompt',
|
||||
'dry_sequence_breakers',
|
||||
'grammar_string',
|
||||
|
|
|
|||
|
|
@ -93,7 +93,6 @@ def create_ui(default_preset):
|
|||
shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
|
||||
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
|
||||
shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.')
|
||||
shared.gradio['show_after'] = gr.Textbox(value=shared.settings['show_after'] or None, label='Show after', info='Hide the reply before this text.', placeholder="</think>")
|
||||
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar'])
|
||||
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
|
||||
with gr.Row() as shared.gradio['grammar_file_row']:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue