mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Initial tool-calling support in the UI
This commit is contained in:
parent
980a9d1657
commit
cf9ad8eafe
189
modules/chat.py
189
modules/chat.py
|
|
@ -298,6 +298,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
|
||||
messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
|
||||
|
||||
# Expand tool_sequence from metadata (inserted AFTER assistant so that
|
||||
# the final order is: user → tool_calls → tool_results → final_answer)
|
||||
meta_key = f"assistant_{row_idx}"
|
||||
tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
|
||||
if tool_seq:
|
||||
for item in reversed(tool_seq):
|
||||
if 'tool_calls' in item:
|
||||
messages.insert(insert_pos, {
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.insert(insert_pos, {
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
|
||||
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
# Check for user message attachments in metadata
|
||||
user_key = f"user_{row_idx}"
|
||||
|
|
@ -367,6 +384,22 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
# Expand tool_sequence for the current entry (excluded from the
|
||||
# history loop during regenerate — needed so the model sees prior
|
||||
# tool calls and results when re-generating the final answer).
|
||||
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
|
||||
for item in current_tool_seq:
|
||||
if 'tool_calls' in item:
|
||||
messages.append({
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.append({
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
|
||||
if impersonate and state['mode'] != 'chat-instruct':
|
||||
messages.append({"role": "user", "content": "fake user message replace me"})
|
||||
|
||||
|
|
@ -886,7 +919,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
}
|
||||
else:
|
||||
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
|
||||
# Store the old response as a version before regenerating
|
||||
|
|
@ -984,7 +1017,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
|
||||
# Keep version metadata in sync during streaming (for regeneration)
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
current_idx = output['metadata'][key]['current_version_index']
|
||||
|
|
@ -1012,7 +1045,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
|
||||
# Final sync for version metadata (in case streaming was disabled)
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
current_idx = output['metadata'][key]['current_version_index']
|
||||
|
|
@ -1066,12 +1099,24 @@ def character_is_loaded(state, raise_exception=False):
|
|||
|
||||
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
||||
'''
|
||||
Same as above but returns HTML for the UI
|
||||
Same as above but returns HTML for the UI.
|
||||
When tools are selected, wraps generation in a loop that detects
|
||||
tool calls, executes them, and re-generates until the model stops.
|
||||
All tool output is consolidated into a single visible chat bubble
|
||||
using metadata['assistant_N']['tool_sequence'].
|
||||
'''
|
||||
|
||||
if not character_is_loaded(state):
|
||||
return
|
||||
|
||||
# On regenerate, clear old tool_sequence metadata so it gets rebuilt
|
||||
if regenerate:
|
||||
history = state['history']
|
||||
meta = history.get('metadata', {})
|
||||
row_idx = len(history['internal']) - 1
|
||||
if row_idx >= 0:
|
||||
meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
|
||||
|
||||
if state['start_with'] != '' and not _continue:
|
||||
if regenerate:
|
||||
text, state['history'] = remove_last_message(state['history'])
|
||||
|
|
@ -1081,23 +1126,139 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
send_dummy_message(text, state)
|
||||
send_dummy_reply(state['start_with'], state)
|
||||
|
||||
history = state['history']
|
||||
# Load tools if any are selected
|
||||
selected = state.get('selected_tools', [])
|
||||
if selected:
|
||||
from modules.tool_use import load_tools, execute_tool, generate_tool_call_id
|
||||
try:
|
||||
from extensions.openai.utils import parseToolCall
|
||||
except ImportError:
|
||||
logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
|
||||
selected = []
|
||||
|
||||
if selected:
|
||||
tool_defs, tool_executors = load_tools(selected)
|
||||
state['tools'] = tool_defs
|
||||
tool_func_names = [t['function']['name'] for t in tool_defs]
|
||||
else:
|
||||
tool_func_names = None
|
||||
|
||||
visible_prefix = [] # Accumulated tool call summaries + results
|
||||
last_save_time = time.monotonic()
|
||||
save_interval = 8
|
||||
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
|
||||
if i == 0:
|
||||
time.sleep(0.125) # We need this to make sure the first update goes through
|
||||
max_tool_turns = 10
|
||||
|
||||
current_time = time.monotonic()
|
||||
# Save on first iteration or if save_interval seconds have passed
|
||||
if i == 0 or (current_time - last_save_time) >= save_interval:
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
last_save_time = current_time
|
||||
for _tool_turn in range(max_tool_turns):
|
||||
history = state['history']
|
||||
|
||||
# Turn 0: use original flags; turns 2+: regenerate into the same entry
|
||||
if _tool_turn > 0:
|
||||
state['_tool_turn'] = True
|
||||
|
||||
regen = regenerate if _tool_turn == 0 else True
|
||||
cont = _continue if _tool_turn == 0 else False
|
||||
cur_text = text if _tool_turn == 0 else ''
|
||||
|
||||
for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
|
||||
# Prepend accumulated tool output to visible reply
|
||||
if visible_prefix:
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [history['visible'][-1][1]])
|
||||
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
|
||||
|
||||
if i == 0:
|
||||
time.sleep(0.125)
|
||||
|
||||
current_time = time.monotonic()
|
||||
if i == 0 or (current_time - last_save_time) >= save_interval:
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
last_save_time = current_time
|
||||
|
||||
# Early stop on tool call detection
|
||||
if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names):
|
||||
break
|
||||
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
# Check for tool calls
|
||||
if not tool_func_names or shared.stop_everything:
|
||||
break
|
||||
|
||||
answer = history['internal'][-1][1]
|
||||
parsed_calls = parseToolCall(answer, tool_func_names) if answer else None
|
||||
|
||||
if not parsed_calls:
|
||||
break # No tool calls — done
|
||||
|
||||
# --- Process tool calls ---
|
||||
row_idx = len(history['internal']) - 1
|
||||
meta = history.get('metadata', {})
|
||||
seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])
|
||||
|
||||
# Serialize tool calls
|
||||
serialized = []
|
||||
for tc in parsed_calls:
|
||||
tc['id'] = generate_tool_call_id()
|
||||
args = tc['function'].get('arguments', {})
|
||||
serialized.append({
|
||||
'id': tc['id'],
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': tc['function']['name'],
|
||||
'arguments': json.dumps(args) if isinstance(args, dict) else args
|
||||
}
|
||||
})
|
||||
|
||||
seq.append({'tool_calls': serialized})
|
||||
|
||||
# Clear internal (raw tool markup)
|
||||
history['internal'][-1][1] = ''
|
||||
|
||||
# Add call summary to visible prefix
|
||||
call_summary = ', '.join(f'{tc["function"]["name"]}(...)' for tc in parsed_calls)
|
||||
visible_prefix.append('Calling: ' + call_summary)
|
||||
|
||||
# Execute tools, store results
|
||||
for tc in parsed_calls:
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
result = execute_tool(fn_name, fn_args, tool_executors)
|
||||
|
||||
seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
|
||||
try:
|
||||
pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pretty_result = result
|
||||
|
||||
visible_prefix.append(f'**{fn_name}**\n```json\n{pretty_result}\n```')
|
||||
|
||||
# Show tool results
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix)
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
state['history'] = history
|
||||
|
||||
state.pop('_tool_turn', None)
|
||||
state['history'] = history
|
||||
|
||||
# Sync version metadata so swipes show the full visible (with tool prefix)
|
||||
if visible_prefix and history.get('metadata'):
|
||||
row_idx = len(history['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
meta_entry = history['metadata'].get(key, {})
|
||||
if 'versions' in meta_entry and 'current_version_index' in meta_entry:
|
||||
current_idx = meta_entry['current_version_index']
|
||||
if current_idx < len(meta_entry['versions']):
|
||||
meta_entry['versions'][current_idx].update({
|
||||
'content': history['internal'][row_idx][1],
|
||||
'visible_content': history['visible'][row_idx][1]
|
||||
})
|
||||
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
|
||||
|
||||
def remove_last_message(history):
|
||||
if 'metadata' not in history:
|
||||
history['metadata'] = {}
|
||||
|
|
|
|||
|
|
@ -263,6 +263,7 @@ settings = {
|
|||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". Reply directly, without starting the reply with the character name.\n\n<|prompt|>',
|
||||
'enable_web_search': False,
|
||||
'web_search_pages': 3,
|
||||
'selected_tools': [],
|
||||
'prompt-notebook': '',
|
||||
'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
|
||||
'max_new_tokens': 512,
|
||||
|
|
|
|||
70
modules/tool_use.py
Normal file
70
modules/tool_use.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import importlib.util
|
||||
import json
|
||||
import random
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def get_available_tools():
    """Return the sorted names (file stems) of tool scripts in user_data/tools/*.py.

    Creates the tools directory on first use so the UI can list it safely.
    """
    tools_dir = shared.user_data_dir / 'tools'
    # Ensure the directory exists so glob() below never fails on a fresh install.
    tools_dir.mkdir(parents=True, exist_ok=True)
    names = [script.stem for script in tools_dir.glob('*.py')]
    names.sort()
    return names
|
||||
|
||||
|
||||
def load_tools(selected_names):
    """
    Import the selected tool scripts from user_data/tools and collect them.

    Each script must define a module-level ``tool`` dict (OpenAI tool format)
    and an ``execute`` callable; scripts missing either are skipped with a
    warning, and scripts that fail to import are skipped with a logged
    traceback.

    Returns a tuple (tool_defs, executors):
    - tool_defs: list of OpenAI-format tool definition dicts
    - executors: dict mapping function name -> the tool's execute callable
    """
    definitions = []
    runners = {}
    for name in selected_names:
        script = shared.user_data_dir / 'tools' / f'{name}.py'
        if not script.exists():
            continue

        # Import the script as an isolated module under a "tool_" name so it
        # does not collide with anything already in sys.modules.
        try:
            spec = importlib.util.spec_from_file_location(f"tool_{name}", str(script))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception:
            logger.exception(f'Failed to load tool script "{name}"')
            continue

        definition = getattr(module, 'tool', None)
        runner = getattr(module, 'execute', None)
        if definition is None or runner is None:
            logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.')
            continue

        definitions.append(definition)
        # Key the executor by the declared function name, falling back to the
        # script's file name when the definition omits one.
        runners[definition.get('function', {}).get('name', name)] = runner

    return definitions, runners
|
||||
|
||||
|
||||
def generate_tool_call_id():
    """Return a random tool call ID of the form 'call_' + 8 lowercase alphanumerics."""
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
    suffix = [random.choice(alphabet) for _ in range(8)]
    return "call_" + "".join(suffix)
|
||||
|
||||
|
||||
def execute_tool(func_name, arguments, executors):
    """
    Run the tool registered under *func_name* and return its result as a string.

    *arguments* may be a dict or a JSON-encoded string (it is decoded first).
    String results are passed through unchanged; anything else is
    JSON-serialized. Unknown tools and any failure during decoding, execution,
    or serialization yield a JSON object of the form {"error": ...} instead of
    raising, so the caller always receives a string.
    """
    executor = executors.get(func_name)
    if executor is None:
        return json.dumps({"error": f"Unknown tool: {func_name}"})

    try:
        parsed_args = json.loads(arguments) if isinstance(arguments, str) else arguments
        result = executor(parsed_args)
        if isinstance(result, str):
            return result
        return json.dumps(result)
    except Exception as e:
        logger.exception(f'Tool "{func_name}" execution failed')
        return json.dumps({"error": str(e)})
|
||||
|
|
@ -199,6 +199,7 @@ def list_interface_input_elements():
|
|||
'unique_id',
|
||||
'textbox',
|
||||
'start_with',
|
||||
'selected_tools',
|
||||
'mode',
|
||||
'chat_style',
|
||||
'chat-instruct_command',
|
||||
|
|
@ -424,6 +425,7 @@ def setup_auto_save():
|
|||
'user_bio',
|
||||
'custom_system_message',
|
||||
'chat_template_str',
|
||||
'selected_tools',
|
||||
|
||||
# Parameters tab (ui_parameters.py) - Generation parameters
|
||||
'preset_menu',
|
||||
|
|
|
|||
|
|
@ -91,6 +91,11 @@ def create_ui():
|
|||
|
||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||
|
||||
from modules.tool_use import get_available_tools
|
||||
shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=[], label='Tools', info='Functions the model can call during generation.')
|
||||
|
||||
gr.HTML("<div class='sidebar-vertical-separator'></div>")
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue