UI: Add a collapsible thinking block to messages with <think> steps (#6902)

oobabooga 2025-04-25 18:02:02 -03:00 committed by GitHub
parent 0dd71e78c9
commit d35818f4e1
8 changed files with 238 additions and 27 deletions

css/main.css

@@ -625,19 +625,19 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
width: 100%;
overflow-y: visible;
}
.message {
break-inside: avoid;
}
.gradio-container {
overflow: visible;
}
.tab-nav {
display: none !important;
}
#chat-tab > :first-child {
max-width: unset;
}
@@ -1308,3 +1308,77 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
padding-left: 1rem;
padding-right: 1rem;
}
/* Thinking blocks styling */
.thinking-block {
margin-bottom: 12px;
border-radius: 8px;
border: 1px solid rgba(0, 0, 0, 0.1);
background-color: var(--light-theme-gray);
overflow: hidden;
}
.dark .thinking-block {
background-color: var(--darker-gray);
}
.thinking-header {
display: flex;
align-items: center;
padding: 10px 16px;
cursor: pointer;
user-select: none;
font-size: 14px;
color: rgba(0, 0, 0, 0.7);
transition: background-color 0.2s;
}
.thinking-header:hover {
background-color: rgba(0, 0, 0, 0.03);
}
.thinking-header::-webkit-details-marker {
display: none;
}
.thinking-icon {
margin-right: 8px;
color: rgba(0, 0, 0, 0.5);
}
.thinking-title {
font-weight: 500;
}
.thinking-content {
padding: 12px 16px;
border-top: 1px solid rgba(0, 0, 0, 0.07);
color: rgba(0, 0, 0, 0.7);
font-size: 14px;
line-height: 1.5;
overflow-wrap: break-word;
max-height: 300px;
overflow-y: scroll;
contain: layout;
}
/* Animation for opening thinking blocks */
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
.thinking-block[open] .thinking-content {
animation: fadeIn 0.3s ease-out;
}
/* Additional style for in-progress thinking */
.thinking-block[data-streaming="true"] .thinking-title {
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0% { opacity: 0.6; }
50% { opacity: 1; }
100% { opacity: 0.6; }
}
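These rules style the <details class="thinking-block"> element that convert_to_markdown() in modules/html_generator.py (further down in this commit) prepends to each reply: .thinking-header is the <summary> toggle, and the pulse animation only runs while data-streaming is "true", i.e. until the closing </think> tag arrives and the title switches from "Thinking..." to "Thought".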

js/main.js

@@ -31,24 +31,94 @@ function removeLastClick() {
}
function handleMorphdomUpdate(text) {
// Track closed blocks
const closedBlocks = new Set();
document.querySelectorAll(".thinking-block").forEach(block => {
const blockId = block.getAttribute("data-block-id");
// If block exists and is not open, add to closed set
if (blockId && !block.hasAttribute("open")) {
closedBlocks.add(blockId);
}
});
// Store scroll positions for any open blocks
const scrollPositions = {};
document.querySelectorAll(".thinking-block[open]").forEach(block => {
const content = block.querySelector(".thinking-content");
const blockId = block.getAttribute("data-block-id");
if (content && blockId) {
const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5;
scrollPositions[blockId] = {
position: content.scrollTop,
isAtBottom: isAtBottom
};
}
});
morphdom(
document.getElementById("chat").parentNode,
"<div class=\"prose svelte-1ybaih5\">" + text + "</div>",
{
onBeforeElUpdated: function(fromEl, toEl) {
// Preserve code highlighting
if (fromEl.tagName === "PRE" && fromEl.querySelector("code[data-highlighted]")) {
const fromCode = fromEl.querySelector("code");
const toCode = toEl.querySelector("code");
if (fromCode && toCode && fromCode.textContent === toCode.textContent) {
// If the <code> content is the same, preserve the entire <pre> element
toEl.className = fromEl.className;
toEl.innerHTML = fromEl.innerHTML;
return false; // Skip updating the <pre> element
}
}
// For thinking blocks, respect closed state
if (fromEl.classList && fromEl.classList.contains("thinking-block") &&
toEl.classList && toEl.classList.contains("thinking-block")) {
const blockId = toEl.getAttribute("data-block-id");
// If this block was closed by user, keep it closed
if (blockId && closedBlocks.has(blockId)) {
toEl.removeAttribute("open");
}
}
return !fromEl.isEqualNode(toEl); // Update only if nodes differ
},
onElUpdated: function(el) {
// Restore scroll positions for open thinking blocks
if (el.classList && el.classList.contains("thinking-block") && el.hasAttribute("open")) {
const blockId = el.getAttribute("data-block-id");
const content = el.querySelector(".thinking-content");
if (content && blockId && scrollPositions[blockId]) {
setTimeout(() => {
if (scrollPositions[blockId].isAtBottom) {
content.scrollTop = content.scrollHeight;
} else {
content.scrollTop = scrollPositions[blockId].position;
}
}, 0);
}
}
}
}
);
// Add toggle listeners for new blocks
document.querySelectorAll(".thinking-block").forEach(block => {
if (!block._hasToggleListener) {
block.addEventListener("toggle", function(e) {
if (this.open) {
const content = this.querySelector(".thinking-content");
if (content) {
setTimeout(() => {
content.scrollTop = content.scrollHeight;
}, 0);
}
}
});
block._hasToggleListener = true;
}
});
}
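This bookkeeping exists because every streamed token re-renders the chat HTML with each thinking block's open attribute set server-side. Without the closedBlocks set, a block the user collapsed would pop back open on the next update; without the saved scroll offsets, an open block would lose its reading position, or stop auto-following the bottom, each time morphdom patches the DOM.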

modules/chat.py

@@ -417,16 +417,8 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
yield history
return
show_after = html.escape(state.get("show_after")) if state.get("show_after") else None
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
if show_after:
after = history["visible"][-1][1].partition(show_after)[2] or "*Is thinking...*"
yield {
'internal': history['internal'],
'visible': history['visible'][:-1] + [[history['visible'][-1][0], after]]
}
else:
yield history
yield history
def character_is_loaded(state, raise_exception=False):
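For context, the lines removed above implemented the generic show_after option: it hid everything in the visible reply up to a user-supplied marker via str.partition, showing an "*Is thinking...*" placeholder until the marker streamed in. A minimal sketch of that old behavior (the values below are hypothetical stand-ins for state["show_after"] and a streamed reply):

import html

# The marker was stored HTML-escaped, matching the escaped chat text.
show_after = html.escape("</think>")  # "&lt;/think&gt;"
reply = "&lt;think&gt;reasoning...&lt;/think&gt;The answer is 4."

# partition(marker)[2] is everything after the first occurrence of the
# marker; before the marker arrived, [2] was empty and the placeholder
# was shown instead.
shown = reply.partition(show_after)[2] or "*Is thinking...*"
assert shown == "The answer is 4."

The dedicated <think> handling in modules/html_generator.py below supersedes this, which is why the setting is also dropped from shared.py, ui.py, ui_parameters.py, and settings-template.yaml later in the commit.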

modules/html_generator.py

@@ -107,8 +107,87 @@ def replace_blockquote(m):
return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
def extract_thinking_block(string):
"""Extract thinking blocks from the beginning of a string."""
if not string:
return None, string
THINK_START_TAG = "&lt;think&gt;"
THINK_END_TAG = "&lt;/think&gt;"
# Look for opening tag
start_pos = string.lstrip().find(THINK_START_TAG)
if start_pos == -1:
return None, string
# Adjust start position to account for any leading whitespace
start_pos = string.find(THINK_START_TAG)
# Find the content after the opening tag
content_start = start_pos + len(THINK_START_TAG)
# Look for closing tag
end_pos = string.find(THINK_END_TAG, content_start)
if end_pos != -1:
# Both tags found - extract content between them
thinking_content = string[content_start:end_pos]
remaining_content = string[end_pos + len(THINK_END_TAG):]
return thinking_content, remaining_content
else:
# Only opening tag found - everything else is thinking content
thinking_content = string[content_start:]
return thinking_content, ""
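A quick usage sketch covering the three cases (note the tags are matched in their HTML-escaped form, since chat text has already been escaped by the time it reaches this function; the inputs are illustrative):

# Complete block: the content between the escaped tags is split from the reply.
think, rest = extract_thinking_block("&lt;think&gt;step 1&lt;/think&gt;Answer")
assert (think, rest) == ("step 1", "Answer")

# Unterminated block (mid-stream): everything after the opening tag counts
# as thinking content, and the visible reply is empty so far.
think, rest = extract_thinking_block("&lt;think&gt;still reasoning")
assert (think, rest) == ("still reasoning", "")

# No opening tag: the string passes through untouched.
think, rest = extract_thinking_block("plain reply")
assert (think, rest) == (None, "plain reply")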
@functools.lru_cache(maxsize=None)
def convert_to_markdown(string):
def convert_to_markdown(string, message_id=None):
if not string:
return ""
# Use a default message ID if none provided
if message_id is None:
message_id = "unknown"
# Extract thinking block if present
thinking_content, remaining_content = extract_thinking_block(string)
# Process the main content
html_output = process_markdown_content(remaining_content)
# If thinking content was found, process it using the same function
if thinking_content is not None:
thinking_html = process_markdown_content(thinking_content)
# Generate unique ID for the thinking block
block_id = f"thinking-{message_id}-0"
# Check if thinking is complete or still in progress
is_streaming = not remaining_content
title_text = "Thinking..." if is_streaming else "Thought"
thinking_block = f'''
<details class="thinking-block" data-block-id="{block_id}" data-streaming="{str(is_streaming).lower()}" open>
<summary class="thinking-header">
<svg class="thinking-icon" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 1.33334C4.31868 1.33334 1.33334 4.31868 1.33334 8.00001C1.33334 11.6813 4.31868 14.6667 8 14.6667C11.6813 14.6667 14.6667 11.6813 14.6667 8.00001C14.6667 4.31868 11.6813 1.33334 8 1.33334Z" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M8 10.6667V8.00001" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M8 5.33334H8.00667" stroke="currentColor" stroke-width="1.33" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span class="thinking-title">{title_text}</span>
</summary>
<div class="thinking-content pretty_scrollbar">{thinking_html}</div>
</details>
'''
# Prepend the thinking block to the message HTML
html_output = thinking_block + html_output
return html_output
def process_markdown_content(string):
"""Process a string through the markdown conversion pipeline."""
if not string:
return ""
@@ -209,15 +288,15 @@ def convert_to_markdown(string):
return html_output
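End to end, a reply that begins with an escaped <think> section now renders as a collapsible details block prepended to the normal reply HTML. An illustrative check (exact reply markup depends on the rest of the markdown pipeline):

html_out = convert_to_markdown("&lt;think&gt;plan the steps&lt;/think&gt;Done.", message_id=3)
assert 'data-block-id="thinking-3-0"' in html_out  # stable id used by morphdom
assert 'data-streaming="false"' in html_out        # the closing tag was seen
assert "Done." in html_out                         # the reply follows the block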
def convert_to_markdown_wrapped(string, use_cache=True):
def convert_to_markdown_wrapped(string, message_id=None, use_cache=True):
'''
Used to avoid caching convert_to_markdown calls during streaming.
'''
if use_cache:
return convert_to_markdown(string)
return convert_to_markdown(string, message_id=message_id)
return convert_to_markdown.__wrapped__(string)
return convert_to_markdown.__wrapped__(string, message_id=message_id)
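The use_cache=False path works because functools.lru_cache exposes the undecorated function as __wrapped__, so per-token streaming updates can skip the cache rather than fill it with throwaway entries. The pattern in isolation:

import functools

@functools.lru_cache(maxsize=None)
def render(s):
    print("rendering")  # marks when the real work runs
    return s.upper()

render("hi")              # prints "rendering"; the result is cached
render("hi")              # cache hit: nothing printed
render.__wrapped__("hi")  # prints again: bypasses the cache entirely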
def generate_basic_html(string):
@@ -273,7 +352,7 @@ def generate_instruct_html(history):
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (
@@ -320,7 +399,7 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (
@@ -360,7 +439,7 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
for i in range(len(history['visible'])):
row_visible = history['visible'][i]
row_internal = history['internal'][i]
converted_visible = [convert_to_markdown_wrapped(entry, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
if converted_visible[0]: # Don't display empty user messages
output += (
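In all three renderers, message_id=i gives each row's thinking block a stable data-block-id (thinking-<row>-0) across re-renders, which is what lets the morphdom hook earlier in this commit re-match a block to its previous open/closed state and scroll position; the existing use_cache condition still keeps only the last, still-streaming row out of the lru_cache.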

modules/shared.py

@@ -59,7 +59,6 @@ settings = {
'seed': -1,
'custom_stopping_strings': '',
'custom_token_bans': '',
'show_after': '',
'negative_prompt': '',
'autoload_model': False,
'dark_theme': True,

modules/ui.py

@@ -207,7 +207,6 @@ def list_interface_input_elements():
'sampler_priority',
'custom_stopping_strings',
'custom_token_bans',
'show_after',
'negative_prompt',
'dry_sequence_breakers',
'grammar_string',

modules/ui_parameters.py

@@ -93,7 +93,6 @@ def create_ui(default_preset):
shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar'])
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.')
shared.gradio['show_after'] = gr.Textbox(value=shared.settings['show_after'] or None, label='Show after', info='Hide the reply before this text.', placeholder="</think>")
shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar'])
shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=generate_params['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.')
with gr.Row() as shared.gradio['grammar_file_row']:

settings-template.yaml

@@ -29,7 +29,6 @@ truncation_length: 8192
seed: -1
custom_stopping_strings: ''
custom_token_bans: ''
show_after: ''
negative_prompt: ''
autoload_model: false
dark_theme: true