mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
UI: Clean up tool calling code
This commit is contained in:
parent
4c7a56c18d
commit
286ae475f6
|
|
@ -1056,9 +1056,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
|
||||
if _continue:
|
||||
# Reprocess the entire internal text for extensions (like translation).
|
||||
# Skip the rebuild when the visible text contains <tool_call> markers,
|
||||
# Skip entirely when the visible text contains <tool_call> markers,
|
||||
# since those only exist in visible (internal is cleared after each tool
|
||||
# execution) and rebuilding from internal would destroy them.
|
||||
# execution) and rebuilding from internal would destroy them. Output
|
||||
# extensions also can't handle the raw <tool_call> markup safely.
|
||||
if '<tool_call>' not in output['visible'][-1][1]:
|
||||
full_internal = output['internal'][-1][1]
|
||||
if state['mode'] in ['chat', 'chat-instruct']:
|
||||
|
|
@ -1069,9 +1070,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
full_visible = html.escape(full_visible)
|
||||
if not state.get('_skip_output_extensions'):
|
||||
output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
|
||||
else:
|
||||
if not state.get('_skip_output_extensions'):
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
else:
|
||||
if not state.get('_skip_output_extensions'):
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
|
|
@ -1141,16 +1139,6 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
if not character_is_loaded(state):
|
||||
return
|
||||
|
||||
# On regenerate, clear old tool_sequence metadata so it gets rebuilt.
|
||||
# Save it first so it can be stored per-version below.
|
||||
_old_tool_sequence = None
|
||||
if regenerate:
|
||||
history = state['history']
|
||||
meta = history.get('metadata', {})
|
||||
row_idx = len(history['internal']) - 1
|
||||
if row_idx >= 0:
|
||||
_old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
|
||||
|
||||
if state['start_with'] != '' and not _continue:
|
||||
if regenerate:
|
||||
text, state['history'] = remove_last_message(state['history'])
|
||||
|
|
@ -1160,13 +1148,25 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
send_dummy_message(text, state)
|
||||
send_dummy_reply(state['start_with'], state)
|
||||
|
||||
# On regenerate, clear old tool_sequence metadata so it gets rebuilt.
|
||||
# Save it first so it can be stored per-version below.
|
||||
# This must happen after the start_with logic above, which may remove
|
||||
# and re-add messages, changing which row we operate on.
|
||||
_old_tool_sequence = None
|
||||
if regenerate:
|
||||
history = state['history']
|
||||
meta = history.get('metadata', {})
|
||||
row_idx = len(history['internal']) - 1
|
||||
if row_idx >= 0:
|
||||
_old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)
|
||||
|
||||
# Load tools if any are selected
|
||||
selected = state.get('selected_tools', [])
|
||||
parseToolCall = None
|
||||
if selected:
|
||||
from modules.tool_use import load_tools, execute_tool, generate_tool_call_id
|
||||
from modules.tool_use import load_tools, execute_tool
|
||||
try:
|
||||
from extensions.openai.utils import parseToolCall
|
||||
from extensions.openai.utils import parseToolCall, getToolCallId
|
||||
except ImportError:
|
||||
logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
|
||||
selected = []
|
||||
|
|
@ -1277,7 +1277,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
serialized = []
|
||||
tc_headers = []
|
||||
for tc in parsed_calls:
|
||||
tc['id'] = generate_tool_call_id()
|
||||
tc['id'] = getToolCallId()
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import importlib.util
|
||||
import json
|
||||
import random
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
|
@ -49,12 +48,6 @@ def load_tools(selected_names):
|
|||
return tool_defs, executors
|
||||
|
||||
|
||||
def generate_tool_call_id():
|
||||
"""Generate a unique tool call ID (e.g. 'call_a1b2c3d4')."""
|
||||
chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
return "call_" + "".join(random.choice(chars) for _ in range(8))
|
||||
|
||||
|
||||
def execute_tool(func_name, arguments, executors):
|
||||
"""Execute a tool by function name. Returns result as a JSON string."""
|
||||
fn = executors.get(func_name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue