Initial tool-calling support in the UI

oobabooga 2026-03-12 01:15:49 -03:00
parent 980a9d1657
commit cf9ad8eafe
5 changed files with 253 additions and 14 deletions

modules/chat.py
@@ -298,6 +298,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
             if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
                 messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
 
+            # Expand tool_sequence from metadata (inserted AFTER assistant so that
+            # the final order is: user → tool_calls → tool_results → final_answer)
+            meta_key = f"assistant_{row_idx}"
+            tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
+            if tool_seq:
+                for item in reversed(tool_seq):
+                    if 'tool_calls' in item:
+                        messages.insert(insert_pos, {
+                            "role": "assistant", "content": "",
+                            "tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
+                        })
+                    elif item.get('role') == 'tool':
+                        messages.insert(insert_pos, {
+                            "role": "tool", "content": item['content'],
+                            "tool_call_id": item.get('tool_call_id', '')
+                        })
+
         if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
             # Check for user message attachments in metadata
             user_key = f"user_{row_idx}"
@@ -367,6 +384,22 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "user", "content": user_input})
 
+    # Expand tool_sequence for the current entry (excluded from the
+    # history loop during regenerate — needed so the model sees prior
+    # tool calls and results when re-generating the final answer).
+    current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
+    for item in current_tool_seq:
+        if 'tool_calls' in item:
+            messages.append({
+                "role": "assistant", "content": "",
+                "tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
+            })
+        elif item.get('role') == 'tool':
+            messages.append({
+                "role": "tool", "content": item['content'],
+                "tool_call_id": item.get('tool_call_id', '')
+            })
+
     if impersonate and state['mode'] != 'chat-instruct':
         messages.append({"role": "user", "content": "fake user message replace me"})
@@ -886,7 +919,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
         }
     else:
         text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
 
-    if regenerate:
+    if regenerate and not state.get('_tool_turn'):
         row_idx = len(output['internal']) - 1
 
         # Store the old response as a version before regenerating
@@ -984,7 +1017,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
             output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
 
             # Keep version metadata in sync during streaming (for regeneration)
-            if regenerate:
+            if regenerate and not state.get('_tool_turn'):
                 row_idx = len(output['internal']) - 1
                 key = f"assistant_{row_idx}"
                 current_idx = output['metadata'][key]['current_version_index']
@@ -1012,7 +1045,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
         output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
 
     # Final sync for version metadata (in case streaming was disabled)
-    if regenerate:
+    if regenerate and not state.get('_tool_turn'):
        row_idx = len(output['internal']) - 1
        key = f"assistant_{row_idx}"
        current_idx = output['metadata'][key]['current_version_index']
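Taken together, the three guards above keep regeneration version metadata untouched on the tool loop's intermediate turns: generate_chat_reply_wrapper (below) re-enters generation with regenerate=True after each tool call and sets state['_tool_turn'], so chatbot_wrapper does not treat those internal re-generations as user-initiated regenerations.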
@@ -1066,12 +1099,24 @@ def character_is_loaded(state, raise_exception=False):
 
 def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
     '''
-    Same as above but returns HTML for the UI
+    Same as above but returns HTML for the UI.
+
+    When tools are selected, wraps generation in a loop that detects
+    tool calls, executes them, and re-generates until the model stops.
+    All tool output is consolidated into a single visible chat bubble
+    using metadata['assistant_N']['tool_sequence'].
     '''
     if not character_is_loaded(state):
         return
 
+    # On regenerate, clear old tool_sequence metadata so it gets rebuilt
+    if regenerate:
+        history = state['history']
+        meta = history.get('metadata', {})
+        row_idx = len(history['internal']) - 1
+        if row_idx >= 0:
+            meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
+
     if state['start_with'] != '' and not _continue:
         if regenerate:
             text, state['history'] = remove_last_message(state['history'])
@@ -1081,23 +1126,139 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
         send_dummy_message(text, state)
         send_dummy_reply(state['start_with'], state)
 
     history = state['history']
+
+    # Load tools if any are selected
+    selected = state.get('selected_tools', [])
+    if selected:
+        from modules.tool_use import load_tools, execute_tool, generate_tool_call_id
+        try:
+            from extensions.openai.utils import parseToolCall
+        except ImportError:
+            logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
+            selected = []
+
+    if selected:
+        tool_defs, tool_executors = load_tools(selected)
+        state['tools'] = tool_defs
+        tool_func_names = [t['function']['name'] for t in tool_defs]
+    else:
+        tool_func_names = None
+
+    visible_prefix = []  # Accumulated tool call summaries + results
     last_save_time = time.monotonic()
     save_interval = 8
-    for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
-        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
-        if i == 0:
-            time.sleep(0.125)  # We need this to make sure the first update goes through
-
-        current_time = time.monotonic()
-        # Save on first iteration or if save_interval seconds have passed
-        if i == 0 or (current_time - last_save_time) >= save_interval:
-            save_history(history, state['unique_id'], state['character_menu'], state['mode'])
-            last_save_time = current_time
+    max_tool_turns = 10
+
+    for _tool_turn in range(max_tool_turns):
+        history = state['history']
+
+        # Turn 0: use original flags; turns 2+: regenerate into the same entry
+        if _tool_turn > 0:
+            state['_tool_turn'] = True
+
+        regen = regenerate if _tool_turn == 0 else True
+        cont = _continue if _tool_turn == 0 else False
+        cur_text = text if _tool_turn == 0 else ''
+
+        for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
+            # Prepend accumulated tool output to visible reply
+            if visible_prefix:
+                history['visible'][-1][1] = '\n\n'.join(visible_prefix + [history['visible'][-1][1]])
+
+            yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
+            if i == 0:
+                time.sleep(0.125)
+
+            current_time = time.monotonic()
+            if i == 0 or (current_time - last_save_time) >= save_interval:
+                save_history(history, state['unique_id'], state['character_menu'], state['mode'])
+                last_save_time = current_time
+
+            # Early stop on tool call detection
+            if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names):
+                break
+
+        save_history(history, state['unique_id'], state['character_menu'], state['mode'])
+
+        # Check for tool calls
+        if not tool_func_names or shared.stop_everything:
+            break
+
+        answer = history['internal'][-1][1]
+        parsed_calls = parseToolCall(answer, tool_func_names) if answer else None
+        if not parsed_calls:
+            break  # No tool calls — done
+
+        # --- Process tool calls ---
+        row_idx = len(history['internal']) - 1
+        meta = history.get('metadata', {})
+        seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])
+
+        # Serialize tool calls
+        serialized = []
+        for tc in parsed_calls:
+            tc['id'] = generate_tool_call_id()
+            args = tc['function'].get('arguments', {})
+            serialized.append({
+                'id': tc['id'],
+                'type': 'function',
+                'function': {
+                    'name': tc['function']['name'],
+                    'arguments': json.dumps(args) if isinstance(args, dict) else args
+                }
+            })
+
+        seq.append({'tool_calls': serialized})
+
+        # Clear internal (raw tool markup)
+        history['internal'][-1][1] = ''
+
+        # Add call summary to visible prefix
+        call_summary = ', '.join(f'{tc["function"]["name"]}(...)' for tc in parsed_calls)
+        visible_prefix.append('Calling: ' + call_summary)
+
+        # Execute tools, store results
+        for tc in parsed_calls:
+            fn_name = tc['function']['name']
+            fn_args = tc['function'].get('arguments', {})
+            result = execute_tool(fn_name, fn_args, tool_executors)
+            seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
+
+            try:
+                pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
+            except (json.JSONDecodeError, TypeError):
+                pretty_result = result
+
+            visible_prefix.append(f'**{fn_name}**\n```json\n{pretty_result}\n```')
+
+        # Show tool results
+        history['visible'][-1][1] = '\n\n'.join(visible_prefix)
+        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
+        save_history(history, state['unique_id'], state['character_menu'], state['mode'])
+        state['history'] = history
+
+    state.pop('_tool_turn', None)
+    state['history'] = history
+
+    # Sync version metadata so swipes show the full visible (with tool prefix)
+    if visible_prefix and history.get('metadata'):
+        row_idx = len(history['internal']) - 1
+        key = f"assistant_{row_idx}"
+        meta_entry = history['metadata'].get(key, {})
+        if 'versions' in meta_entry and 'current_version_index' in meta_entry:
+            current_idx = meta_entry['current_version_index']
+            if current_idx < len(meta_entry['versions']):
+                meta_entry['versions'][current_idx].update({
+                    'content': history['internal'][row_idx][1],
+                    'visible_content': history['visible'][row_idx][1]
+                })
+
     save_history(history, state['unique_id'], state['character_menu'], state['mode'])
 
 
 def remove_last_message(history):
     if 'metadata' not in history:
         history['metadata'] = {}
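The loop above relies on parseToolCall returning a truthy, mutable list of call dicts. A rough sketch of the shapes involved, inferred from how the code indexes them (the concrete tool name and arguments are hypothetical):

    # Assumed shape of parseToolCall(answer, tool_func_names) when a call is found:
    parsed_calls = [
        {'function': {'name': 'get_weather', 'arguments': {'city': 'Paris'}}},
    ]

    # The loop then stamps each call with tc['id'] = generate_tool_call_id() and
    # appends two kinds of entries to metadata['assistant_N']['tool_sequence']:
    #   {'tool_calls': [<serialized calls, dict arguments JSON-encoded>]}
    #   {'role': 'tool', 'content': <execute_tool result>, 'tool_call_id': <id>}
    # which generate_chat_prompt later expands back into assistant/tool messages.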

modules/shared.py
@@ -263,6 +263,7 @@ settings = {
     'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
     'enable_web_search': False,
     'web_search_pages': 3,
+    'selected_tools': [],
     'prompt-notebook': '',
     'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
     'max_new_tokens': 512,

modules/tool_use.py (new file)

@@ -0,0 +1,70 @@
+import importlib.util
+import json
+import random
+
+from modules import shared
+from modules.logging_colors import logger
+
+
+def get_available_tools():
+    """Return sorted list of tool script names from user_data/tools/*.py."""
+    tools_dir = shared.user_data_dir / 'tools'
+    tools_dir.mkdir(parents=True, exist_ok=True)
+    return sorted(p.stem for p in tools_dir.glob('*.py'))
+
+
+def load_tools(selected_names):
+    """
+    Import selected tool scripts and return their definitions and executors.
+
+    Returns (tool_defs, executors) where:
+    - tool_defs: list of OpenAI-format tool dicts
+    - executors: dict mapping function_name -> execute callable
+    """
+    tool_defs = []
+    executors = {}
+    for name in selected_names:
+        path = shared.user_data_dir / 'tools' / f'{name}.py'
+        if not path.exists():
+            continue
+        try:
+            spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path))
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+        except Exception:
+            logger.exception(f'Failed to load tool script "{name}"')
+            continue
+
+        tool_def = getattr(module, 'tool', None)
+        execute_fn = getattr(module, 'execute', None)
+        if tool_def is None or execute_fn is None:
+            logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
+            continue
+
+        func_name = tool_def.get('function', {}).get('name', name)
+        tool_defs.append(tool_def)
+        executors[func_name] = execute_fn
+
+    return tool_defs, executors
+
+
+def generate_tool_call_id():
+    """Generate a unique tool call ID (e.g. 'call_a1b2c3d4')."""
+    chars = "abcdefghijklmnopqrstuvwxyz0123456789"
+    return "call_" + "".join(random.choice(chars) for _ in range(8))
+
+
+def execute_tool(func_name, arguments, executors):
+    """Execute a tool by function name. Returns result as a JSON string."""
+    fn = executors.get(func_name)
+    if fn is None:
+        return json.dumps({"error": f"Unknown tool: {func_name}"})
+
+    try:
+        if isinstance(arguments, str):
+            arguments = json.loads(arguments)
+        result = fn(arguments)
+        return json.dumps(result) if not isinstance(result, str) else result
+    except Exception as e:
+        logger.exception(f'Tool "{func_name}" execution failed')
+        return json.dumps({"error": str(e)})
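Given the contract load_tools() expects (a module-level 'tool' dict in OpenAI function format plus an execute(arguments) callable), a user tool script might look like the following. This is a hypothetical example for illustration, not a file from the commit:

    # user_data/tools/get_time.py
    from datetime import datetime, timezone

    # OpenAI-format definition; load_tools() reads this module-level dict.
    tool = {
        'type': 'function',
        'function': {
            'name': 'get_time',
            'description': 'Return the current UTC time.',
            'parameters': {'type': 'object', 'properties': {}, 'required': []},
        },
    }

    # Receives the parsed arguments dict; a non-string return value is
    # JSON-encoded by execute_tool().
    def execute(arguments):
        return {'utc_time': datetime.now(timezone.utc).isoformat()}

Once saved under user_data/tools/, the script's stem appears as a choice in the Tools checkbox group added to the UI below.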

modules/ui.py
@@ -199,6 +199,7 @@ def list_interface_input_elements():
         'unique_id',
         'textbox',
         'start_with',
+        'selected_tools',
         'mode',
         'chat_style',
         'chat-instruct_command',
@@ -424,6 +425,7 @@ def setup_auto_save():
         'user_bio',
         'custom_system_message',
         'chat_template_str',
+        'selected_tools',
 
         # Parameters tab (ui_parameters.py) - Generation parameters
         'preset_menu',

modules/ui_chat.py
@@ -91,6 +91,11 @@ def create_ui():
             gr.HTML("<div class='sidebar-vertical-separator'></div>")
 
+            from modules.tool_use import get_available_tools
+            shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=[], label='Tools', info='Functions the model can call during generation.')
+
+            gr.HTML("<div class='sidebar-vertical-separator'></div>")
+
             with gr.Row():
                 shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')