mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Optimize tool call detection
Avoids templates that don't contain a given necessary keyword
This commit is contained in:
parent
d0a4993cf4
commit
573617157a
|
|
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
|||
from extensions.openai.errors import InvalidRequestError
|
||||
from extensions.openai.typing import ToolDefinition
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules.tool_parsing import get_tool_call_id, parse_tool_call
|
||||
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
|
||||
from modules import shared
|
||||
from modules.reasoning import extract_reasoning
|
||||
from modules.chat import (
|
||||
|
|
@ -484,6 +484,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
tool_calls = []
|
||||
end_last_tool_call = 0
|
||||
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
|
||||
_tool_parsers = None
|
||||
|
||||
# Filter supported_tools when tool_choice specifies a particular function
|
||||
if supported_tools and isinstance(tool_choice, dict):
|
||||
|
|
@ -491,11 +492,15 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
if specified_func and specified_func in supported_tools:
|
||||
supported_tools = [specified_func]
|
||||
|
||||
if supported_tools is not None:
|
||||
_template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
|
||||
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
|
||||
|
||||
for a in generator:
|
||||
answer = a['internal'][-1][1]
|
||||
|
||||
if supported_tools is not None:
|
||||
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
|
||||
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
|
||||
if len(tool_call) > 0:
|
||||
for tc in tool_call:
|
||||
tc["id"] = get_tool_call_id()
|
||||
|
|
|
|||
|
|
@ -1035,8 +1035,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
_check_tool_markers = bool(state.get('tools'))
|
||||
_last_visible_before_tool_buffer = None
|
||||
if _check_tool_markers:
|
||||
from modules.tool_parsing import streaming_tool_buffer_check
|
||||
from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format
|
||||
_tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']]
|
||||
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
|
||||
_, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str)
|
||||
|
||||
# Generate
|
||||
reply = None
|
||||
|
|
@ -1088,7 +1090,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
|
||||
if is_stream:
|
||||
if _check_tool_markers:
|
||||
if streaming_tool_buffer_check(output['internal'][-1][1], _tool_names):
|
||||
if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
|
||||
continue
|
||||
_last_visible_before_tool_buffer = output['visible'][-1][1]
|
||||
|
||||
|
|
@ -1128,7 +1130,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
# visible text from before buffering started so raw markup doesn't flash
|
||||
# in the UI. The internal text is left intact so the caller can still
|
||||
# parse tool calls from it.
|
||||
if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], _tool_names):
|
||||
if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
|
||||
output['visible'][-1][1] = _last_visible_before_tool_buffer or ''
|
||||
|
||||
yield output
|
||||
|
|
@ -1210,14 +1212,17 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
# Load tools if any are selected
|
||||
selected = state.get('selected_tools', [])
|
||||
parse_tool_call = None
|
||||
_tool_parsers = None
|
||||
if selected:
|
||||
from modules.tool_use import load_tools, execute_tool
|
||||
from modules.tool_parsing import parse_tool_call, get_tool_call_id
|
||||
from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format
|
||||
|
||||
if selected:
|
||||
tool_defs, tool_executors = load_tools(selected)
|
||||
state['tools'] = tool_defs
|
||||
tool_func_names = [t['function']['name'] for t in tool_defs]
|
||||
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
|
||||
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
|
||||
else:
|
||||
tool_func_names = None
|
||||
|
||||
|
|
@ -1272,7 +1277,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
last_save_time = current_time
|
||||
|
||||
# Early stop on tool call detection
|
||||
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names):
|
||||
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers):
|
||||
break
|
||||
|
||||
# Save the model's visible output before re-applying visible_prefix,
|
||||
|
|
@ -1304,7 +1309,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
break
|
||||
|
||||
answer = history['internal'][-1][1]
|
||||
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '')
|
||||
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '')
|
||||
|
||||
if not parsed_calls:
|
||||
break # No tool calls — done
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ def get_tool_call_id() -> str:
|
|||
return "call_" + "".join(b).lower()
|
||||
|
||||
|
||||
# Known opening markers for tool calls across model formats.
|
||||
# Used during streaming to buffer output that might be tool call markup,
|
||||
# preventing raw markup from leaking into displayed/streamed content.
|
||||
# All known opening markers for tool calls across model formats.
|
||||
TOOL_CALL_OPENING_MARKERS = [
|
||||
'<tool_call>',
|
||||
'<function_call>',
|
||||
|
|
@ -25,36 +23,47 @@ TOOL_CALL_OPENING_MARKERS = [
|
|||
'<|channel|>commentary',
|
||||
]
|
||||
|
||||
def streaming_tool_buffer_check(text, tool_names=None):
|
||||
|
||||
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
|
||||
'''
|
||||
Check whether streaming output should be withheld because it may
|
||||
contain tool-call markup.
|
||||
|
||||
Args:
|
||||
text: Full accumulated internal text.
|
||||
markers: Template-specific markers for partial-prefix matching.
|
||||
If None, falls back to TOOL_CALL_OPENING_MARKERS.
|
||||
tool_names: List of tool function names.
|
||||
check_bare_names: Whether to do partial-prefix matching on tool
|
||||
names (for models with unknown template format).
|
||||
'''
|
||||
# Full marker found → buffer permanently
|
||||
# Full marker found in text → buffer permanently.
|
||||
# Always checks ALL known markers regardless of template (cheap safety net).
|
||||
for marker in TOOL_CALL_OPENING_MARKERS:
|
||||
if marker in text:
|
||||
return True
|
||||
|
||||
# Bare function-name style (e.g. Devstral): "get_weather{...}"
|
||||
# Only match tool name followed by '{' to avoid false positives on
|
||||
# common words that happen to be tool names (e.g. "get", "search").
|
||||
# Bare function-name full match: "get_weather{...}" or "get_weather {...}"
|
||||
if tool_names:
|
||||
for name in tool_names:
|
||||
if name + '{' in text or name + ' {' in text:
|
||||
return True
|
||||
# Partial: text ends with tool name (or prefix of it) but '{' hasn't arrived yet
|
||||
|
||||
# Partial-prefix matching: only for template-specific markers.
|
||||
for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
|
||||
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
|
||||
if text.endswith(marker[:prefix_len]):
|
||||
return True
|
||||
|
||||
# Bare-name partial matching: only when template format is unknown.
|
||||
if check_bare_names and tool_names:
|
||||
for name in tool_names:
|
||||
if text.endswith(name):
|
||||
return True
|
||||
for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
|
||||
if text.endswith(name[:prefix_len]):
|
||||
return True
|
||||
|
||||
# Tail might be a partial marker forming across tokens
|
||||
for marker in TOOL_CALL_OPENING_MARKERS:
|
||||
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
|
||||
if text.endswith(marker[:prefix_len]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -488,7 +497,102 @@ def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
|
|||
return matches, start_pos
|
||||
|
||||
|
||||
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False):
|
||||
# Format registry: maps template substrings to the parser and streaming
|
||||
# markers for that format. When a format's hints are NOT found in the
|
||||
# template, its parser and markers are excluded.
|
||||
TOOL_CALL_FORMATS = [
|
||||
{
|
||||
'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
|
||||
'parser': _parse_deep_seek_tool_calls,
|
||||
'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
|
||||
'parser': _parse_kimi_tool_calls,
|
||||
'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['to=functions.', '<|channel|>'],
|
||||
'parser': _parse_channel_tool_calls,
|
||||
'markers': ['to=functions.', '<|channel|>commentary'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['minimax:tool_call'],
|
||||
'parser': _parse_minimax_tool_calls,
|
||||
'markers': ['<minimax:tool_call>'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['<arg_key>'],
|
||||
'parser': _parse_glm_tool_calls,
|
||||
'markers': ['<tool_call>'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['<tool_call>'],
|
||||
'parser': _parse_xml_param_tool_calls,
|
||||
'markers': ['<tool_call>'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['[TOOL_CALLS]'],
|
||||
'parser': _parse_mistral_token_tool_calls,
|
||||
'markers': ['[TOOL_CALLS]'],
|
||||
},
|
||||
{
|
||||
'template_hints': ['<function_call>'],
|
||||
'parser': None,
|
||||
'markers': ['<function_call>'],
|
||||
},
|
||||
]
|
||||
|
||||
# Default ordered list of all specialized parsers.
|
||||
ALL_PARSERS = [
|
||||
_parse_deep_seek_tool_calls,
|
||||
_parse_kimi_tool_calls,
|
||||
_parse_channel_tool_calls,
|
||||
_parse_minimax_tool_calls,
|
||||
_parse_glm_tool_calls,
|
||||
_parse_xml_param_tool_calls,
|
||||
_parse_mistral_token_tool_calls,
|
||||
_parse_bare_name_tool_calls,
|
||||
_parse_pythonic_tool_calls,
|
||||
]
|
||||
|
||||
|
||||
def detect_tool_call_format(template_str):
|
||||
"""Inspect a chat/instruction template to determine which tool call
|
||||
formats are relevant.
|
||||
|
||||
Uses an exclude-based approach: starts with all parsers/markers,
|
||||
then removes the ones whose hints are not found in the template.
|
||||
|
||||
Returns (parsers, streaming_markers, check_bare_names).
|
||||
"""
|
||||
if not template_str:
|
||||
return None, TOOL_CALL_OPENING_MARKERS, True
|
||||
|
||||
matched_any = False
|
||||
exclude_parsers = []
|
||||
exclude_markers = []
|
||||
matched_markers = []
|
||||
|
||||
for fmt in TOOL_CALL_FORMATS:
|
||||
if any(hint in template_str for hint in fmt['template_hints']):
|
||||
matched_any = True
|
||||
matched_markers.extend(fmt['markers'])
|
||||
else:
|
||||
if fmt['parser'] is not None:
|
||||
exclude_parsers.append(fmt['parser'])
|
||||
exclude_markers.extend(fmt['markers'])
|
||||
|
||||
if not matched_any:
|
||||
return None, TOOL_CALL_OPENING_MARKERS, True
|
||||
|
||||
parsers = [p for p in ALL_PARSERS if p not in exclude_parsers]
|
||||
markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers]
|
||||
|
||||
return parsers, markers, False
|
||||
|
||||
|
||||
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
|
||||
matches = []
|
||||
start_pos = None
|
||||
|
||||
|
|
@ -498,52 +602,13 @@ def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = Fa
|
|||
return matches, prefix
|
||||
return matches
|
||||
|
||||
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
|
||||
matches, start_pos = _parse_deep_seek_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
# Try specialized parsers.
|
||||
for parser in (parsers if parsers is not None else ALL_PARSERS):
|
||||
matches, start_pos = parser(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
|
||||
matches, start_pos = _parse_kimi_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for channel-based tool calls (e.g. GPT-OSS format)
|
||||
matches, start_pos = _parse_channel_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
|
||||
matches, start_pos = _parse_minimax_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
|
||||
matches, start_pos = _parse_glm_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
|
||||
matches, start_pos = _parse_xml_param_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
|
||||
matches, start_pos = _parse_mistral_token_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for bare function-name style tool calls (e.g. Mistral format)
|
||||
matches, start_pos = _parse_bare_name_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for pythonic-style tool calls (e.g. Llama 4 format)
|
||||
matches, start_pos = _parse_pythonic_tool_calls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
||||
# Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
||||
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
||||
|
||||
for pattern in patterns:
|
||||
|
|
|
|||
Loading…
Reference in a new issue