Optimize tool call detection

Skips tool call parsers and streaming markers for formats whose identifying keyword does not appear in the chat/instruction template.
This commit is contained in:
oobabooga 2026-03-14 12:09:41 -07:00
parent d0a4993cf4
commit 573617157a
3 changed files with 144 additions and 69 deletions

View file

@ -12,7 +12,7 @@ from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError
from extensions.openai.typing import ToolDefinition
from extensions.openai.utils import debug_msg
from modules.tool_parsing import get_tool_call_id, parse_tool_call
from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format
from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import (
@ -484,6 +484,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
tool_calls = []
end_last_tool_call = 0
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
_tool_parsers = None
# Filter supported_tools when tool_choice specifies a particular function
if supported_tools and isinstance(tool_choice, dict):
@ -491,11 +492,15 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if specified_func and specified_func in supported_tools:
supported_tools = [specified_func]
if supported_tools is not None:
_template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '')
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
for a in generator:
answer = a['internal'][-1][1]
if supported_tools is not None:
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else []
if len(tool_call) > 0:
for tc in tool_call:
tc["id"] = get_tool_call_id()

View file

@ -1035,8 +1035,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
_check_tool_markers = bool(state.get('tools'))
_last_visible_before_tool_buffer = None
if _check_tool_markers:
from modules.tool_parsing import streaming_tool_buffer_check
from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format
_tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']]
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
_, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str)
# Generate
reply = None
@ -1088,7 +1090,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
if is_stream:
if _check_tool_markers:
if streaming_tool_buffer_check(output['internal'][-1][1], _tool_names):
if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
continue
_last_visible_before_tool_buffer = output['visible'][-1][1]
@ -1128,7 +1130,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
# visible text from before buffering started so raw markup doesn't flash
# in the UI. The internal text is left intact so the caller can still
# parse tool calls from it.
if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], _tool_names):
if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
output['visible'][-1][1] = _last_visible_before_tool_buffer or ''
yield output
@ -1210,14 +1212,17 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
# Load tools if any are selected
selected = state.get('selected_tools', [])
parse_tool_call = None
_tool_parsers = None
if selected:
from modules.tool_use import load_tools, execute_tool
from modules.tool_parsing import parse_tool_call, get_tool_call_id
from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format
if selected:
tool_defs, tool_executors = load_tools(selected)
state['tools'] = tool_defs
tool_func_names = [t['function']['name'] for t in tool_defs]
_template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
_tool_parsers, _, _ = detect_tool_call_format(_template_str)
else:
tool_func_names = None
@ -1272,7 +1277,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
last_save_time = current_time
# Early stop on tool call detection
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names):
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers):
break
# Save the model's visible output before re-applying visible_prefix,
@ -1304,7 +1309,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
break
answer = history['internal'][-1][1]
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '')
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '')
if not parsed_calls:
break # No tool calls — done

View file

@ -9,9 +9,7 @@ def get_tool_call_id() -> str:
return "call_" + "".join(b).lower()
# Known opening markers for tool calls across model formats.
# Used during streaming to buffer output that might be tool call markup,
# preventing raw markup from leaking into displayed/streamed content.
# All known opening markers for tool calls across model formats.
TOOL_CALL_OPENING_MARKERS = [
'<tool_call>',
'<function_call>',
@ -25,36 +23,47 @@ TOOL_CALL_OPENING_MARKERS = [
'<|channel|>commentary',
]
def streaming_tool_buffer_check(text, tool_names=None):
def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False):
'''
Check whether streaming output should be withheld because it may
contain tool-call markup.
Args:
text: Full accumulated internal text.
markers: Template-specific markers for partial-prefix matching.
If None, falls back to TOOL_CALL_OPENING_MARKERS.
tool_names: List of tool function names.
check_bare_names: Whether to do partial-prefix matching on tool
names (for models with unknown template format).
'''
# Full marker found → buffer permanently
# Full marker found in text → buffer permanently.
# Always checks ALL known markers regardless of template (cheap safety net).
for marker in TOOL_CALL_OPENING_MARKERS:
if marker in text:
return True
# Bare function-name style (e.g. Devstral): "get_weather{...}"
# Only match tool name followed by '{' to avoid false positives on
# common words that happen to be tool names (e.g. "get", "search").
# Bare function-name full match: "get_weather{...}" or "get_weather {...}"
if tool_names:
for name in tool_names:
if name + '{' in text or name + ' {' in text:
return True
# Partial: text ends with tool name (or prefix of it) but '{' hasn't arrived yet
# Partial-prefix matching: only for template-specific markers.
for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS):
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
if text.endswith(marker[:prefix_len]):
return True
# Bare-name partial matching: only when template format is unknown.
if check_bare_names and tool_names:
for name in tool_names:
if text.endswith(name):
return True
for prefix_len in range(min(len(name) - 1, len(text)), 0, -1):
if text.endswith(name[:prefix_len]):
return True
# Tail might be a partial marker forming across tokens
for marker in TOOL_CALL_OPENING_MARKERS:
for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1):
if text.endswith(marker[:prefix_len]):
return True
return False
@ -488,7 +497,102 @@ def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
return matches, start_pos
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False):
# Format registry: maps template substrings to the parser and streaming
# markers for that format. When a format's hints are NOT found in the
# template, its parser and markers are excluded.
#
# Keys of each entry:
#   'template_hints': substrings looked up in the template text; one hit is
#       enough for the format to count as present.
#   'parser': the parser function for the format, or None when the format
#       has no dedicated parser.
#   'markers': opening markers contributed to streaming partial-prefix
#       matching when the format is present.
TOOL_CALL_FORMATS = [
# DeepSeek-style tool calls.
{
'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'],
'parser': _parse_deep_seek_tool_calls,
'markers': ['<tool▁call▁begin>', '<tool▁calls▁begin>'],
},
# Kimi-K2-style tool calls (pipe-delimited tokens).
{
'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'],
'parser': _parse_kimi_tool_calls,
'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'],
},
# Channel-based tool calls (e.g. GPT-OSS format).
{
'template_hints': ['to=functions.', '<|channel|>'],
'parser': _parse_channel_tool_calls,
'markers': ['to=functions.', '<|channel|>commentary'],
},
# MiniMax-style tool calls.
{
'template_hints': ['minimax:tool_call'],
'parser': _parse_minimax_tool_calls,
'markers': ['<minimax:tool_call>'],
},
# GLM-style tool calls (arg_key/arg_value pairs). Uses the same
# '<tool_call>' marker as the XML-parameter entry below, so only the
# '<arg_key>' hint tells the two apart.
{
'template_hints': ['<arg_key>'],
'parser': _parse_glm_tool_calls,
'markers': ['<tool_call>'],
},
# XML-parameter style tool calls (e.g. Qwen3.5 format).
{
'template_hints': ['<tool_call>'],
'parser': _parse_xml_param_tool_calls,
'markers': ['<tool_call>'],
},
# Mistral/Devstral-style tool calls introduced by the [TOOL_CALLS] token.
{
'template_hints': ['[TOOL_CALLS]'],
'parser': _parse_mistral_token_tool_calls,
'markers': ['[TOOL_CALLS]'],
},
# No dedicated parser: this entry only contributes its streaming marker.
# NOTE(review): presumably parsed by the generic tag fallback in
# parse_tool_call — confirm.
{
'template_hints': ['<function_call>'],
'parser': None,
'markers': ['<function_call>'],
},
]
# Default ordered list of all specialized parsers.
# parse_tool_call tries the parsers in list order and returns the first
# non-empty result, so position here is priority.
# The last two entries (bare-name and pythonic) are not referenced by any
# TOOL_CALL_FORMATS entry, so detect_tool_call_format never excludes them.
ALL_PARSERS = [
_parse_deep_seek_tool_calls,
_parse_kimi_tool_calls,
_parse_channel_tool_calls,
_parse_minimax_tool_calls,
_parse_glm_tool_calls,
_parse_xml_param_tool_calls,
_parse_mistral_token_tool_calls,
_parse_bare_name_tool_calls,
_parse_pythonic_tool_calls,
]
def detect_tool_call_format(template_str):
    """Work out which tool call formats a chat/instruction template uses.

    Every entry of TOOL_CALL_FORMATS is tested against the template: a
    format counts as present when at least one of its 'template_hints'
    occurs in ``template_str``. Parsers and markers belonging only to
    absent formats are then removed from the full default lists.

    Args:
        template_str: The template text to inspect. May be empty or None.

    Returns:
        A ``(parsers, streaming_markers, check_bare_names)`` tuple.

        When the template is empty or matches no known format, the result
        is ``(None, TOOL_CALL_OPENING_MARKERS, True)``: no parser
        restriction, every known marker, and bare-name prefix matching
        switched on.

        Otherwise ``parsers`` is ALL_PARSERS (order preserved) without the
        parsers of absent formats, ``streaming_markers`` is
        TOOL_CALL_OPENING_MARKERS (order preserved) without the markers of
        absent formats, and ``check_bare_names`` is False.
    """
    if not template_str:
        return None, TOOL_CALL_OPENING_MARKERS, True

    # Split the registry into formats the template mentions and formats
    # it does not.
    present = []
    absent = []
    for spec in TOOL_CALL_FORMATS:
        hinted = any(hint in template_str for hint in spec['template_hints'])
        (present if hinted else absent).append(spec)

    if not present:
        return None, TOOL_CALL_OPENING_MARKERS, True

    dropped_parsers = [spec['parser'] for spec in absent if spec['parser'] is not None]
    dropped_markers = [marker for spec in absent for marker in spec['markers']]
    kept_markers = [marker for spec in present for marker in spec['markers']]

    parsers = [candidate for candidate in ALL_PARSERS if candidate not in dropped_parsers]
    # A marker shared by a present and an absent format must survive, so
    # membership in kept_markers overrides membership in dropped_markers.
    markers = [
        marker
        for marker in TOOL_CALL_OPENING_MARKERS
        if marker in kept_markers or marker not in dropped_markers
    ]
    return parsers, markers, False
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None):
matches = []
start_pos = None
@ -498,52 +602,13 @@ def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = Fa
return matches, prefix
return matches
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
matches, start_pos = _parse_deep_seek_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Try specialized parsers.
for parser in (parsers if parsers is not None else ALL_PARSERS):
matches, start_pos = parser(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
matches, start_pos = _parse_kimi_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for channel-based tool calls (e.g. GPT-OSS format)
matches, start_pos = _parse_channel_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
matches, start_pos = _parse_minimax_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
matches, start_pos = _parse_glm_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
matches, start_pos = _parse_xml_param_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
matches, start_pos = _parse_mistral_token_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for bare function-name style tool calls (e.g. Mistral format)
matches, start_pos = _parse_bare_name_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for pythonic-style tool calls (e.g. Llama 4 format)
matches, start_pos = _parse_pythonic_tool_calls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
# Generic fallback: regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
for pattern in patterns: