mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-09 08:33:41 +00:00
Initial tool-calling support in the UI
This commit is contained in:
parent
980a9d1657
commit
cf9ad8eafe
5 changed files with 253 additions and 14 deletions
189
modules/chat.py
189
modules/chat.py
|
|
@ -298,6 +298,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
|
||||
messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
|
||||
|
||||
# Expand tool_sequence from metadata (inserted AFTER assistant so that
|
||||
# the final order is: user → tool_calls → tool_results → final_answer)
|
||||
meta_key = f"assistant_{row_idx}"
|
||||
tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
|
||||
if tool_seq:
|
||||
for item in reversed(tool_seq):
|
||||
if 'tool_calls' in item:
|
||||
messages.insert(insert_pos, {
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.insert(insert_pos, {
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
|
||||
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
# Check for user message attachments in metadata
|
||||
user_key = f"user_{row_idx}"
|
||||
|
|
@ -367,6 +384,22 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
# Expand tool_sequence for the current entry (excluded from the
|
||||
# history loop during regenerate — needed so the model sees prior
|
||||
# tool calls and results when re-generating the final answer).
|
||||
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
|
||||
for item in current_tool_seq:
|
||||
if 'tool_calls' in item:
|
||||
messages.append({
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.append({
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
|
||||
if impersonate and state['mode'] != 'chat-instruct':
|
||||
messages.append({"role": "user", "content": "fake user message replace me"})
|
||||
|
||||
|
|
@ -886,7 +919,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
}
|
||||
else:
|
||||
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
|
||||
# Store the old response as a version before regenerating
|
||||
|
|
@ -984,7 +1017,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
|
||||
# Keep version metadata in sync during streaming (for regeneration)
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
current_idx = output['metadata'][key]['current_version_index']
|
||||
|
|
@ -1012,7 +1045,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
|
||||
# Final sync for version metadata (in case streaming was disabled)
|
||||
if regenerate:
|
||||
if regenerate and not state.get('_tool_turn'):
|
||||
row_idx = len(output['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
current_idx = output['metadata'][key]['current_version_index']
|
||||
|
|
@ -1066,12 +1099,24 @@ def character_is_loaded(state, raise_exception=False):
|
|||
|
||||
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
||||
'''
|
||||
Same as above but returns HTML for the UI
|
||||
Same as above but returns HTML for the UI.
|
||||
When tools are selected, wraps generation in a loop that detects
|
||||
tool calls, executes them, and re-generates until the model stops.
|
||||
All tool output is consolidated into a single visible chat bubble
|
||||
using metadata['assistant_N']['tool_sequence'].
|
||||
'''
|
||||
|
||||
if not character_is_loaded(state):
|
||||
return
|
||||
|
||||
# On regenerate, clear old tool_sequence metadata so it gets rebuilt
|
||||
if regenerate:
|
||||
history = state['history']
|
||||
meta = history.get('metadata', {})
|
||||
row_idx = len(history['internal']) - 1
|
||||
if row_idx >= 0:
|
||||
meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
|
||||
|
||||
if state['start_with'] != '' and not _continue:
|
||||
if regenerate:
|
||||
text, state['history'] = remove_last_message(state['history'])
|
||||
|
|
@ -1081,23 +1126,139 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
send_dummy_message(text, state)
|
||||
send_dummy_reply(state['start_with'], state)
|
||||
|
||||
history = state['history']
|
||||
# Load tools if any are selected
|
||||
selected = state.get('selected_tools', [])
|
||||
if selected:
|
||||
from modules.tool_use import load_tools, execute_tool, generate_tool_call_id
|
||||
try:
|
||||
from extensions.openai.utils import parseToolCall
|
||||
except ImportError:
|
||||
logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
|
||||
selected = []
|
||||
|
||||
if selected:
|
||||
tool_defs, tool_executors = load_tools(selected)
|
||||
state['tools'] = tool_defs
|
||||
tool_func_names = [t['function']['name'] for t in tool_defs]
|
||||
else:
|
||||
tool_func_names = None
|
||||
|
||||
visible_prefix = [] # Accumulated tool call summaries + results
|
||||
last_save_time = time.monotonic()
|
||||
save_interval = 8
|
||||
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
|
||||
if i == 0:
|
||||
time.sleep(0.125) # We need this to make sure the first update goes through
|
||||
max_tool_turns = 10
|
||||
|
||||
current_time = time.monotonic()
|
||||
# Save on first iteration or if save_interval seconds have passed
|
||||
if i == 0 or (current_time - last_save_time) >= save_interval:
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
last_save_time = current_time
|
||||
for _tool_turn in range(max_tool_turns):
|
||||
history = state['history']
|
||||
|
||||
# Turn 0: use original flags; turns 2+: regenerate into the same entry
|
||||
if _tool_turn > 0:
|
||||
state['_tool_turn'] = True
|
||||
|
||||
regen = regenerate if _tool_turn == 0 else True
|
||||
cont = _continue if _tool_turn == 0 else False
|
||||
cur_text = text if _tool_turn == 0 else ''
|
||||
|
||||
for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
|
||||
# Prepend accumulated tool output to visible reply
|
||||
if visible_prefix:
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + [history['visible'][-1][1]])
|
||||
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
|
||||
|
||||
if i == 0:
|
||||
time.sleep(0.125)
|
||||
|
||||
current_time = time.monotonic()
|
||||
if i == 0 or (current_time - last_save_time) >= save_interval:
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
last_save_time = current_time
|
||||
|
||||
# Early stop on tool call detection
|
||||
if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names):
|
||||
break
|
||||
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
# Check for tool calls
|
||||
if not tool_func_names or shared.stop_everything:
|
||||
break
|
||||
|
||||
answer = history['internal'][-1][1]
|
||||
parsed_calls = parseToolCall(answer, tool_func_names) if answer else None
|
||||
|
||||
if not parsed_calls:
|
||||
break # No tool calls — done
|
||||
|
||||
# --- Process tool calls ---
|
||||
row_idx = len(history['internal']) - 1
|
||||
meta = history.get('metadata', {})
|
||||
seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])
|
||||
|
||||
# Serialize tool calls
|
||||
serialized = []
|
||||
for tc in parsed_calls:
|
||||
tc['id'] = generate_tool_call_id()
|
||||
args = tc['function'].get('arguments', {})
|
||||
serialized.append({
|
||||
'id': tc['id'],
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': tc['function']['name'],
|
||||
'arguments': json.dumps(args) if isinstance(args, dict) else args
|
||||
}
|
||||
})
|
||||
|
||||
seq.append({'tool_calls': serialized})
|
||||
|
||||
# Clear internal (raw tool markup)
|
||||
history['internal'][-1][1] = ''
|
||||
|
||||
# Add call summary to visible prefix
|
||||
call_summary = ', '.join(f'{tc["function"]["name"]}(...)' for tc in parsed_calls)
|
||||
visible_prefix.append('Calling: ' + call_summary)
|
||||
|
||||
# Execute tools, store results
|
||||
for tc in parsed_calls:
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
result = execute_tool(fn_name, fn_args, tool_executors)
|
||||
|
||||
seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
|
||||
try:
|
||||
pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pretty_result = result
|
||||
|
||||
visible_prefix.append(f'**{fn_name}**\n```json\n{pretty_result}\n```')
|
||||
|
||||
# Show tool results
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix)
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
state['history'] = history
|
||||
|
||||
state.pop('_tool_turn', None)
|
||||
state['history'] = history
|
||||
|
||||
# Sync version metadata so swipes show the full visible (with tool prefix)
|
||||
if visible_prefix and history.get('metadata'):
|
||||
row_idx = len(history['internal']) - 1
|
||||
key = f"assistant_{row_idx}"
|
||||
meta_entry = history['metadata'].get(key, {})
|
||||
if 'versions' in meta_entry and 'current_version_index' in meta_entry:
|
||||
current_idx = meta_entry['current_version_index']
|
||||
if current_idx < len(meta_entry['versions']):
|
||||
meta_entry['versions'][current_idx].update({
|
||||
'content': history['internal'][row_idx][1],
|
||||
'visible_content': history['visible'][row_idx][1]
|
||||
})
|
||||
|
||||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
|
||||
|
||||
def remove_last_message(history):
|
||||
if 'metadata' not in history:
|
||||
history['metadata'] = {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue