diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 51427050..fc17a19a 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -12,7 +12,7 @@ from pydantic import ValidationError from extensions.openai.errors import InvalidRequestError from extensions.openai.typing import ToolDefinition from extensions.openai.utils import debug_msg -from modules.tool_parsing import get_tool_call_id, parse_tool_call +from modules.tool_parsing import get_tool_call_id, parse_tool_call, detect_tool_call_format from modules import shared from modules.reasoning import extract_reasoning from modules.chat import ( @@ -484,6 +484,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p tool_calls = [] end_last_tool_call = 0 supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None + _tool_parsers = None # Filter supported_tools when tool_choice specifies a particular function if supported_tools and isinstance(tool_choice, dict): @@ -491,11 +492,15 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p if specified_func and specified_func in supported_tools: supported_tools = [specified_func] + if supported_tools is not None: + _template_str = generate_params.get('instruction_template_str', '') if generate_params.get('mode') == 'instruct' else generate_params.get('chat_template_str', '') + _tool_parsers, _, _ = detect_tool_call_format(_template_str) + for a in generator: answer = a['internal'][-1][1] if supported_tools is not None: - tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else [] + tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools, parsers=_tool_parsers) if len(answer) > 0 else [] if len(tool_call) > 0: for tc in tool_call: tc["id"] = get_tool_call_id() diff --git a/modules/chat.py b/modules/chat.py index 08f55539..1ffbb56b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1035,8 +1035,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess _check_tool_markers = bool(state.get('tools')) _last_visible_before_tool_buffer = None if _check_tool_markers: - from modules.tool_parsing import streaming_tool_buffer_check + from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format _tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']] + _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '') + _, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str) # Generate reply = None @@ -1088,7 +1090,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess if is_stream: if _check_tool_markers: - if streaming_tool_buffer_check(output['internal'][-1][1], _tool_names): + if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names): continue _last_visible_before_tool_buffer = output['visible'][-1][1] @@ -1128,7 +1130,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess # visible text from before buffering started so raw markup doesn't flash # in the UI. The internal text is left intact so the caller can still # parse tool calls from it. - if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], _tool_names): + if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names): output['visible'][-1][1] = _last_visible_before_tool_buffer or '' yield output @@ -1210,14 +1212,17 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): # Load tools if any are selected selected = state.get('selected_tools', []) parse_tool_call = None + _tool_parsers = None if selected: from modules.tool_use import load_tools, execute_tool - from modules.tool_parsing import parse_tool_call, get_tool_call_id + from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format if selected: tool_defs, tool_executors = load_tools(selected) state['tools'] = tool_defs tool_func_names = [t['function']['name'] for t in tool_defs] + _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '') + _tool_parsers, _, _ = detect_tool_call_format(_template_str) else: tool_func_names = None @@ -1272,7 +1277,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): last_save_time = current_time # Early stop on tool call detection - if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names): + if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers): break # Save the model's visible output before re-applying visible_prefix, @@ -1304,7 +1309,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): break answer = history['internal'][-1][1] - parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '') + parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '') if not parsed_calls: break # No tool calls — done diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py index 418503ad..0454e901 100644 --- a/modules/tool_parsing.py +++ b/modules/tool_parsing.py @@ -9,9 +9,7 @@ def get_tool_call_id() -> str: return "call_" + "".join(b).lower() -# Known opening markers for tool calls across model formats. -# Used during streaming to buffer output that might be tool call markup, -# preventing raw markup from leaking into displayed/streamed content. +# All known opening markers for tool calls across model formats. TOOL_CALL_OPENING_MARKERS = [ '', '', @@ -25,36 +23,47 @@ TOOL_CALL_OPENING_MARKERS = [ '<|channel|>commentary', ] -def streaming_tool_buffer_check(text, tool_names=None): + +def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False): ''' Check whether streaming output should be withheld because it may contain tool-call markup. + + Args: + text: Full accumulated internal text. + markers: Template-specific markers for partial-prefix matching. + If None, falls back to TOOL_CALL_OPENING_MARKERS. + tool_names: List of tool function names. + check_bare_names: Whether to do partial-prefix matching on tool + names (for models with unknown template format). ''' - # Full marker found → buffer permanently + # Full marker found in text → buffer permanently. + # Always checks ALL known markers regardless of template (cheap safety net). for marker in TOOL_CALL_OPENING_MARKERS: if marker in text: return True - # Bare function-name style (e.g. Devstral): "get_weather{...}" - # Only match tool name followed by '{' to avoid false positives on - # common words that happen to be tool names (e.g. "get", "search"). + # Bare function-name full match: "get_weather{...}" or "get_weather {...}" if tool_names: for name in tool_names: if name + '{' in text or name + ' {' in text: return True - # Partial: text ends with tool name (or prefix of it) but '{' hasn't arrived yet + + # Partial-prefix matching: only for template-specific markers. + for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS): + for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1): + if text.endswith(marker[:prefix_len]): + return True + + # Bare-name partial matching: only when template format is unknown. + if check_bare_names and tool_names: + for name in tool_names: if text.endswith(name): return True for prefix_len in range(min(len(name) - 1, len(text)), 0, -1): if text.endswith(name[:prefix_len]): return True - # Tail might be a partial marker forming across tokens - for marker in TOOL_CALL_OPENING_MARKERS: - for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1): - if text.endswith(marker[:prefix_len]): - return True - return False @@ -488,7 +497,102 @@ def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): return matches, start_pos -def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False): +# Format registry: maps template substrings to the parser and streaming +# markers for that format. When a format's hints are NOT found in the +# template, its parser and markers are excluded. +TOOL_CALL_FORMATS = [ + { + 'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'], + 'parser': _parse_deep_seek_tool_calls, + 'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'], + }, + { + 'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'], + 'parser': _parse_kimi_tool_calls, + 'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'], + }, + { + 'template_hints': ['to=functions.', '<|channel|>'], + 'parser': _parse_channel_tool_calls, + 'markers': ['to=functions.', '<|channel|>commentary'], + }, + { + 'template_hints': ['minimax:tool_call'], + 'parser': _parse_minimax_tool_calls, + 'markers': [''], + }, + { + 'template_hints': [''], + 'parser': _parse_glm_tool_calls, + 'markers': [''], + }, + { + 'template_hints': [''], + 'parser': _parse_xml_param_tool_calls, + 'markers': [''], + }, + { + 'template_hints': ['[TOOL_CALLS]'], + 'parser': _parse_mistral_token_tool_calls, + 'markers': ['[TOOL_CALLS]'], + }, + { + 'template_hints': [''], + 'parser': None, + 'markers': [''], + }, +] + +# Default ordered list of all specialized parsers. +ALL_PARSERS = [ + _parse_deep_seek_tool_calls, + _parse_kimi_tool_calls, + _parse_channel_tool_calls, + _parse_minimax_tool_calls, + _parse_glm_tool_calls, + _parse_xml_param_tool_calls, + _parse_mistral_token_tool_calls, + _parse_bare_name_tool_calls, + _parse_pythonic_tool_calls, +] + + +def detect_tool_call_format(template_str): + """Inspect a chat/instruction template to determine which tool call + formats are relevant. + + Uses an exclude-based approach: starts with all parsers/markers, + then removes the ones whose hints are not found in the template. + + Returns (parsers, streaming_markers, check_bare_names). + """ + if not template_str: + return None, TOOL_CALL_OPENING_MARKERS, True + + matched_any = False + exclude_parsers = [] + exclude_markers = [] + matched_markers = [] + + for fmt in TOOL_CALL_FORMATS: + if any(hint in template_str for hint in fmt['template_hints']): + matched_any = True + matched_markers.extend(fmt['markers']) + else: + if fmt['parser'] is not None: + exclude_parsers.append(fmt['parser']) + exclude_markers.extend(fmt['markers']) + + if not matched_any: + return None, TOOL_CALL_OPENING_MARKERS, True + + parsers = [p for p in ALL_PARSERS if p not in exclude_parsers] + markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers] + + return parsers, markers, False + + +def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None): matches = [] start_pos = None @@ -498,52 +602,13 @@ def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = Fa return matches, prefix return matches - # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters) - matches, start_pos = _parse_deep_seek_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) + # Try specialized parsers. + for parser in (parsers if parsers is not None else ALL_PARSERS): + matches, start_pos = parser(answer, tool_names) + if matches: + return _return(matches, start_pos) - # Check for Kimi-K2-style tool calls (pipe-delimited tokens) - matches, start_pos = _parse_kimi_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for channel-based tool calls (e.g. GPT-OSS format) - matches, start_pos = _parse_channel_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for MiniMax-style tool calls (invoke/parameter XML tags) - matches, start_pos = _parse_minimax_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for GLM-style tool calls (arg_key/arg_value XML pairs) - matches, start_pos = _parse_glm_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for XML-parameter style tool calls (e.g. Qwen3.5 format) - matches, start_pos = _parse_xml_param_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json) - matches, start_pos = _parse_mistral_token_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for bare function-name style tool calls (e.g. Mistral format) - matches, start_pos = _parse_bare_name_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for pythonic-style tool calls (e.g. Llama 4 format) - matches, start_pos = _parse_pythonic_tool_calls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Define the regex pattern to find the JSON content wrapped in , , , and other tags observed from various models + # Generic fallback: regex pattern to find the JSON content wrapped in , , , and other tags observed from various models patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"] for pattern in patterns: