From e8d1c663037666bafc0a45f4be0471a88fda4d57 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 13 Mar 2026 18:13:12 -0700
Subject: [PATCH] Clean up tool calling code
---
extensions/openai/completions.py | 7 +-
extensions/openai/utils.py | 558 -------------------------------
modules/chat.py | 17 +-
modules/tool_parsing.py | 553 ++++++++++++++++++++++++++++++
modules/tool_use.py | 6 +-
modules/ui_chat.py | 2 +-
modules/web_search.py | 37 +-
7 files changed, 604 insertions(+), 576 deletions(-)
create mode 100644 modules/tool_parsing.py
diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 290a5bc0..27defe42 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -11,7 +11,8 @@ from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError
from extensions.openai.typing import ToolDefinition
-from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
+from extensions.openai.utils import debug_msg
+from modules.tool_parsing import get_tool_call_id, parse_tool_call
from modules import shared
from modules.reasoning import extract_reasoning
from modules.chat import (
@@ -491,10 +492,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
answer = a['internal'][-1][1]
if supported_tools is not None:
- tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
+ tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
if len(tool_call) > 0:
for tc in tool_call:
- tc["id"] = getToolCallId()
+ tc["id"] = get_tool_call_id()
if stream:
tc["index"] = len(tool_calls)
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py
index b179c267..2b414769 100644
--- a/extensions/openai/utils.py
+++ b/extensions/openai/utils.py
@@ -1,8 +1,5 @@
import base64
-import json
import os
-import random
-import re
import time
import traceback
from typing import Callable, Optional
@@ -55,558 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3)
raise Exception('Could not start cloudflared.')
-
-
-def getToolCallId() -> str:
- letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
- b = [random.choice(letter_bytes) for _ in range(8)]
- return "call_" + "".join(b).lower()
-
-
-def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
- # check if property 'function' exists and is a dictionary, otherwise adapt dict
- if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
- candidate_dict = {"type": "function", "function": candidate_dict}
- if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
- candidate_dict['name'] = candidate_dict['function']
- del candidate_dict['function']
- candidate_dict = {"type": "function", "function": candidate_dict}
- if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
- # check if 'name' exists within 'function' and is part of known tools
- if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
- candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
- # map property 'parameters' used by some older models to 'arguments'
- if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
- candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
- del candidate_dict["function"]["parameters"]
- return candidate_dict
- return None
-
-
-def _extractBalancedJson(text: str, start: int) -> str | None:
- """Extract a balanced JSON object from text starting at the given position.
-
- Walks through the string tracking brace depth and string boundaries
- to correctly handle arbitrary nesting levels.
- """
- if start >= len(text) or text[start] != '{':
- return None
- depth = 0
- in_string = False
- escape_next = False
- for i in range(start, len(text)):
- c = text[i]
- if escape_next:
- escape_next = False
- continue
- if c == '\\' and in_string:
- escape_next = True
- continue
- if c == '"':
- in_string = not in_string
- continue
- if in_string:
- continue
- if c == '{':
- depth += 1
- elif c == '}':
- depth -= 1
- if depth == 0:
- return text[start:i + 1]
- return None
-
-
-def _parseChannelToolCalls(answer: str, tool_names: list[str]):
- """Parse channel-based tool calls used by GPT-OSS and similar models.
-
- Format:
- <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
- or:
- <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
- """
- matches = []
- start_pos = None
- # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
- # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
- patterns = [
- r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
- r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
- ]
- for pattern in patterns:
- for m in re.finditer(pattern, answer):
- func_name = m.group(1).strip()
- if func_name not in tool_names:
- continue
- json_str = _extractBalancedJson(answer, m.end())
- if json_str is None:
- continue
- try:
- arguments = json.loads(json_str)
- if start_pos is None:
- prefix = answer.rfind('<|start|>assistant', 0, m.start())
- start_pos = prefix if prefix != -1 else m.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- except json.JSONDecodeError:
- pass
- if matches:
- break
- return matches, start_pos
-
-
-def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]):
- """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
-
- Format:
- [TOOL_CALLS]func_name[ARGS]{"arg": "value"}
- """
- matches = []
- start_pos = None
- for m in re.finditer(
- r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
- answer
- ):
- func_name = m.group(1).strip()
- if func_name not in tool_names:
- continue
- json_str = _extractBalancedJson(answer, m.end())
- if json_str is None:
- continue
- try:
- arguments = json.loads(json_str)
- if start_pos is None:
- start_pos = m.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- except json.JSONDecodeError:
- pass
- return matches, start_pos
-
-
-def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
- """Parse bare function-name style tool calls used by Mistral and similar models.
-
- Format:
- functionName{"arg": "value"}
- Multiple calls are concatenated directly or separated by whitespace.
- """
- matches = []
- start_pos = None
- # Match tool name followed by opening brace, then extract balanced JSON
- escaped_names = [re.escape(name) for name in tool_names]
- pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
- for match in re.finditer(pattern, answer):
- text = match.group(0)
- name = None
- for n in tool_names:
- if text.startswith(n):
- name = n
- break
- if not name:
- continue
- brace_start = match.end() - 1
- json_str = _extractBalancedJson(answer, brace_start)
- if json_str is None:
- continue
- try:
- arguments = json.loads(json_str)
- if start_pos is None:
- start_pos = match.start()
- matches.append({
- "type": "function",
- "function": {
- "name": name,
- "arguments": arguments
- }
- })
- except json.JSONDecodeError:
- pass
- return matches, start_pos
-
-
-def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
- """Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
-
- Format:
- <tool_call>
- <function=func_name>
- <parameter=param_name>value</parameter>
- </function>
- </tool_call>
- """
- matches = []
- start_pos = None
- for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
- tc_content = tc_match.group(1)
- func_match = re.search(r'<function=([^>]+)>', tc_content)
- if not func_match:
- continue
- func_name = func_match.group(1).strip()
- if func_name not in tool_names:
- continue
- arguments = {}
- for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
- param_name = param_match.group(1).strip()
- param_value = param_match.group(2).strip()
- try:
- param_value = json.loads(param_value)
- except (json.JSONDecodeError, ValueError):
- pass # keep as string
- arguments[param_name] = param_value
- if start_pos is None:
- start_pos = tc_match.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- return matches, start_pos
-
-
-def _parseKimiToolCalls(answer: str, tool_names: list[str]):
- """Parse Kimi-K2-style tool calls using pipe-delimited tokens.
-
- Format:
- <|tool_calls_section_begin|>
- <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
- <|tool_calls_section_end|>
- """
- matches = []
- start_pos = None
- for m in re.finditer(
- r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
- answer
- ):
- func_name = m.group(1).strip()
- if func_name not in tool_names:
- continue
- json_str = _extractBalancedJson(answer, m.end())
- if json_str is None:
- continue
- try:
- arguments = json.loads(json_str)
- if start_pos is None:
- # Check for section begin marker before the call marker
- section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
- start_pos = section if section != -1 else m.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- except json.JSONDecodeError:
- pass
- return matches, start_pos
-
-
-def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
- """Parse MiniMax-style tool calls using invoke/parameter XML tags.
-
- Format:
- <minimax:tool_call>
- <invoke name="func_name">
- <parameter name="param_name">value</parameter>
- </invoke>
- </minimax:tool_call>
- """
- matches = []
- start_pos = None
- for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
- tc_content = tc_match.group(1)
- # Split on <invoke> to handle multiple parallel calls in one block
- for invoke_match in re.finditer(r'<invoke name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
- func_name = invoke_match.group(1).strip()
- if func_name not in tool_names:
- continue
- invoke_body = invoke_match.group(2)
- arguments = {}
- for param_match in re.finditer(r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
- param_name = param_match.group(1).strip()
- param_value = param_match.group(2).strip()
- try:
- param_value = json.loads(param_value)
- except (json.JSONDecodeError, ValueError):
- pass # keep as string
- arguments[param_name] = param_value
- if start_pos is None:
- start_pos = tc_match.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- return matches, start_pos
-
-
-def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
- """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
-
- Format:
- <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
- """
- matches = []
- start_pos = None
- for m in re.finditer(
- r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
- answer
- ):
- func_name = m.group(1).strip()
- if func_name not in tool_names:
- continue
- json_str = _extractBalancedJson(answer, m.end())
- if json_str is None:
- continue
- try:
- arguments = json.loads(json_str)
- if start_pos is None:
- # Check for section begin marker before the call marker
- section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
- start_pos = section if section != -1 else m.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- except json.JSONDecodeError:
- pass
- return matches, start_pos
-
-
-def _parseGlmToolCalls(answer: str, tool_names: list[str]):
- """Parse GLM-style tool calls using arg_key/arg_value XML pairs.
-
- Format:
- <tool_call>function_name
- <arg_key>key1</arg_key>
- <arg_value>value1</arg_value>
- </tool_call>
- """
- matches = []
- start_pos = None
- for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
- tc_content = tc_match.group(1)
- # First non-tag text is the function name
- name_match = re.match(r'([^<\s]+)', tc_content.strip())
- if not name_match:
- continue
- func_name = name_match.group(1).strip()
- if func_name not in tool_names:
- continue
- # Extract arg_key/arg_value pairs
- keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
- vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
- if len(keys) != len(vals):
- continue
- arguments = {}
- for k, v in zip(keys, vals):
- try:
- v = json.loads(v)
- except (json.JSONDecodeError, ValueError):
- pass # keep as string
- arguments[k] = v
- if start_pos is None:
- start_pos = tc_match.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
- return matches, start_pos
-
-
-def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
- """Parse pythonic-style tool calls used by Llama 4 and similar models.
-
- Format:
- [func_name(param1="value1", param2="value2"), func_name2(...)]
- """
- matches = []
- start_pos = None
- # Match a bracketed list of function calls
- bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
- if not bracket_match:
- return matches, start_pos
-
- inner = bracket_match.group(1)
-
- # Build pattern for known tool names
- escaped_names = [re.escape(name) for name in tool_names]
- name_pattern = '|'.join(escaped_names)
-
- for call_match in re.finditer(
- r'(' + name_pattern + r')\(([^)]*)\)',
- inner
- ):
- func_name = call_match.group(1)
- params_str = call_match.group(2).strip()
- arguments = {}
-
- if params_str:
- # Parse key="value" pairs, handling commas inside quoted values
- for param_match in re.finditer(
- r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
- params_str
- ):
- param_name = param_match.group(1)
- param_value = param_match.group(2).strip()
- # Strip surrounding quotes
- if (param_value.startswith('"') and param_value.endswith('"')) or \
- (param_value.startswith("'") and param_value.endswith("'")):
- param_value = param_value[1:-1]
- # Try to parse as JSON for numeric/bool/null values
- try:
- param_value = json.loads(param_value)
- except (json.JSONDecodeError, ValueError):
- pass
- arguments[param_name] = param_value
-
- if start_pos is None:
- start_pos = bracket_match.start()
- matches.append({
- "type": "function",
- "function": {
- "name": func_name,
- "arguments": arguments
- }
- })
-
- return matches, start_pos
-
-
-def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = False):
- matches = []
- start_pos = None
-
- def _return(matches, start_pos):
- if return_prefix:
- prefix = answer[:start_pos] if matches and start_pos is not None else ''
- return matches, prefix
- return matches
-
- # abort on very short answers to save computation cycles
- if len(answer) < 10:
- return _return(matches, start_pos)
-
- # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
- matches, start_pos = _parseDeepSeekToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for Kimi-K2-style tool calls (pipe-delimited tokens)
- matches, start_pos = _parseKimiToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for channel-based tool calls (e.g. GPT-OSS format)
- matches, start_pos = _parseChannelToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for MiniMax-style tool calls (invoke/parameter XML tags)
- matches, start_pos = _parseMiniMaxToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for GLM-style tool calls (arg_key/arg_value XML pairs)
- matches, start_pos = _parseGlmToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
- matches, start_pos = _parseXmlParamToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
- matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for bare function-name style tool calls (e.g. Mistral format)
- matches, start_pos = _parseBareNameToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Check for pythonic-style tool calls (e.g. Llama 4 format)
- matches, start_pos = _parsePythonicToolCalls(answer, tool_names)
- if matches:
- return _return(matches, start_pos)
-
- # Define the regex pattern to find the JSON content wrapped in <tool_call>, <function_call>, <tools>, and other tags observed from various models
- patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
-
- for pattern in patterns:
- for match in re.finditer(pattern, answer, re.DOTALL):
- # print(match.group(2))
- if match.group(2) is None:
- continue
- # remove backtick wraps if present
- candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
- candidate = re.sub(r"```$", "", candidate.strip())
- # unwrap inner tags
- candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
- # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
- if re.search(r"\}\s*\n\s*\{", candidate) is not None:
- candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
- if not candidate.strip().startswith("["):
- candidate = "[" + candidate + "]"
-
- candidates = []
- try:
- # parse the candidate JSON into a dictionary
- candidates = json.loads(candidate)
- if not isinstance(candidates, list):
- candidates = [candidates]
- except json.JSONDecodeError:
- # Ignore invalid JSON silently
- continue
-
- for candidate_dict in candidates:
- checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
- if checked_candidate is not None:
- if start_pos is None:
- start_pos = match.start()
- matches.append(checked_candidate)
-
- # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
- if len(matches) == 0:
- try:
- candidate = answer
- # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
- if re.search(r"\}\s*\n\s*\{", candidate) is not None:
- candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
- if not candidate.strip().startswith("["):
- candidate = "[" + candidate + "]"
- # parse the candidate JSON into a dictionary
- candidates = json.loads(candidate)
- if not isinstance(candidates, list):
- candidates = [candidates]
- for candidate_dict in candidates:
- checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
- if checked_candidate is not None:
- matches.append(checked_candidate)
- except json.JSONDecodeError:
- # Ignore invalid JSON silently
- pass
-
- return _return(matches, start_pos)
diff --git a/modules/chat.py b/modules/chat.py
index 87e52851..02ae46e4 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -239,6 +239,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
name1=state['name1'],
name2=state['name2'],
user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']),
+ tools=state['tools'] if 'tools' in state else None,
)
messages = []
@@ -1186,14 +1187,10 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
# Load tools if any are selected
selected = state.get('selected_tools', [])
- parseToolCall = None
+ parse_tool_call = None
if selected:
from modules.tool_use import load_tools, execute_tool
- try:
- from extensions.openai.utils import parseToolCall, getToolCallId
- except ImportError:
- logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
- selected = []
+ from modules.tool_parsing import parse_tool_call, get_tool_call_id
if selected:
tool_defs, tool_executors = load_tools(selected)
@@ -1253,7 +1250,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
last_save_time = current_time
# Early stop on tool call detection
- if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names):
+ if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names):
break
# Save the model's visible output before re-applying visible_prefix,
@@ -1285,7 +1282,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
break
answer = history['internal'][-1][1]
- parsed_calls, content_prefix = parseToolCall(answer, tool_func_names, return_prefix=True) if answer else (None, '')
+ parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '')
if not parsed_calls:
break # No tool calls — done
@@ -1302,7 +1299,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
serialized = []
tc_headers = []
for tc in parsed_calls:
- tc['id'] = getToolCallId()
+ tc['id'] = get_tool_call_id()
fn_name = tc['function']['name']
fn_args = tc['function'].get('arguments', {})
@@ -1343,7 +1340,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
# Preserve thinking block and intermediate text from this turn.
# content_prefix is the raw text before tool call syntax (returned
- # by parseToolCall); HTML-escape it and extract thinking to get
+ # by parse_tool_call); HTML-escape it and extract thinking to get
# the content the user should see.
content_text = html.escape(content_prefix)
thinking_content, intermediate = extract_thinking_block(content_text)
diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py
new file mode 100644
index 00000000..460188d3
--- /dev/null
+++ b/modules/tool_parsing.py
@@ -0,0 +1,553 @@
+import json
+import random
+import re
+
+
+def get_tool_call_id() -> str:
+ letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+ b = [random.choice(letter_bytes) for _ in range(8)]
+ return "call_" + "".join(b).lower()
+
+
+def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
+ # check if property 'function' exists and is a dictionary, otherwise adapt dict
+ if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
+ candidate_dict = {"type": "function", "function": candidate_dict}
+ if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
+ candidate_dict['name'] = candidate_dict['function']
+ del candidate_dict['function']
+ candidate_dict = {"type": "function", "function": candidate_dict}
+ if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
+ # check if 'name' exists within 'function' and is part of known tools
+ if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
+ candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
+ # map property 'parameters' used by some older models to 'arguments'
+ if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
+ candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
+ del candidate_dict["function"]["parameters"]
+ return candidate_dict
+ return None
+
+
+def _extract_balanced_json(text: str, start: int) -> str | None:
+ """Extract a balanced JSON object from text starting at the given position.
+
+ Walks through the string tracking brace depth and string boundaries
+ to correctly handle arbitrary nesting levels.
+ """
+ if start >= len(text) or text[start] != '{':
+ return None
+ depth = 0
+ in_string = False
+ escape_next = False
+ for i in range(start, len(text)):
+ c = text[i]
+ if escape_next:
+ escape_next = False
+ continue
+ if c == '\\' and in_string:
+ escape_next = True
+ continue
+ if c == '"':
+ in_string = not in_string
+ continue
+ if in_string:
+ continue
+ if c == '{':
+ depth += 1
+ elif c == '}':
+ depth -= 1
+ if depth == 0:
+ return text[start:i + 1]
+ return None
+
+
+def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
+ """Parse channel-based tool calls used by GPT-OSS and similar models.
+
+ Format:
+ <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
+ or:
+ <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}
+ """
+ matches = []
+ start_pos = None
+ # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
+ # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
+ patterns = [
+ r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
+ r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
+ ]
+ for pattern in patterns:
+ for m in re.finditer(pattern, answer):
+ func_name = m.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ json_str = _extract_balanced_json(answer, m.end())
+ if json_str is None:
+ continue
+ try:
+ arguments = json.loads(json_str)
+ if start_pos is None:
+ prefix = answer.rfind('<|start|>assistant', 0, m.start())
+ start_pos = prefix if prefix != -1 else m.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ except json.JSONDecodeError:
+ pass
+ if matches:
+ break
+ return matches, start_pos
+
+
+def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
+ """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
+
+ Format:
+ [TOOL_CALLS]func_name[ARGS]{"arg": "value"}
+ """
+ matches = []
+ start_pos = None
+ for m in re.finditer(
+ r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
+ answer
+ ):
+ func_name = m.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ json_str = _extract_balanced_json(answer, m.end())
+ if json_str is None:
+ continue
+ try:
+ arguments = json.loads(json_str)
+ if start_pos is None:
+ start_pos = m.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ except json.JSONDecodeError:
+ pass
+ return matches, start_pos
+
+
+def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
+ """Parse bare function-name style tool calls used by Mistral and similar models.
+
+ Format:
+ functionName{"arg": "value"}
+ Multiple calls are concatenated directly or separated by whitespace.
+ """
+ matches = []
+ start_pos = None
+ # Match tool name followed by opening brace, then extract balanced JSON
+ escaped_names = [re.escape(name) for name in tool_names]
+ pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
+ for match in re.finditer(pattern, answer):
+ text = match.group(0)
+ name = None
+ for n in tool_names:
+ if text.startswith(n):
+ name = n
+ break
+ if not name:
+ continue
+ brace_start = match.end() - 1
+ json_str = _extract_balanced_json(answer, brace_start)
+ if json_str is None:
+ continue
+ try:
+ arguments = json.loads(json_str)
+ if start_pos is None:
+ start_pos = match.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": name,
+ "arguments": arguments
+ }
+ })
+ except json.JSONDecodeError:
+ pass
+ return matches, start_pos
+
+
+def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
+ """Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
+
+ Format:
+ <tool_call>
+ <function=func_name>
+ <parameter=param_name>value</parameter>
+ </function>
+ </tool_call>
+ """
+ matches = []
+ start_pos = None
+ for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
+ tc_content = tc_match.group(1)
+ func_match = re.search(r'<function=([^>]+)>', tc_content)
+ if not func_match:
+ continue
+ func_name = func_match.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ arguments = {}
+ for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
+ param_name = param_match.group(1).strip()
+ param_value = param_match.group(2).strip()
+ try:
+ param_value = json.loads(param_value)
+ except (json.JSONDecodeError, ValueError):
+ pass # keep as string
+ arguments[param_name] = param_value
+ if start_pos is None:
+ start_pos = tc_match.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ return matches, start_pos
+
+
+def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
+ """Parse Kimi-K2-style tool calls using pipe-delimited tokens.
+
+ Format:
+ <|tool_calls_section_begin|>
+ <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
+ <|tool_calls_section_end|>
+ """
+ matches = []
+ start_pos = None
+ for m in re.finditer(
+ r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
+ answer
+ ):
+ func_name = m.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ json_str = _extract_balanced_json(answer, m.end())
+ if json_str is None:
+ continue
+ try:
+ arguments = json.loads(json_str)
+ if start_pos is None:
+ # Check for section begin marker before the call marker
+ section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
+ start_pos = section if section != -1 else m.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ except json.JSONDecodeError:
+ pass
+ return matches, start_pos
+
+
+def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
+ """Parse MiniMax-style tool calls using invoke/parameter XML tags.
+
+ Format:
+ <minimax:tool_call>
+ <invoke name="func_name">
+ <parameter name="param_name">value</parameter>
+ </invoke>
+ </minimax:tool_call>
+ """
+ matches = []
+ start_pos = None
+ for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
+ tc_content = tc_match.group(1)
+ # Split on <invoke> to handle multiple parallel calls in one block
+ for invoke_match in re.finditer(r'<invoke name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
+ func_name = invoke_match.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ invoke_body = invoke_match.group(2)
+ arguments = {}
+ for param_match in re.finditer(r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
+ param_name = param_match.group(1).strip()
+ param_value = param_match.group(2).strip()
+ try:
+ param_value = json.loads(param_value)
+ except (json.JSONDecodeError, ValueError):
+ pass # keep as string
+ arguments[param_name] = param_value
+ if start_pos is None:
+ start_pos = tc_match.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ return matches, start_pos
+
+
+def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
+ """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
+
+ Format:
+ <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
+ """
+ matches = []
+ start_pos = None
+ for m in re.finditer(
+ r'<\|tool▁call\u2581begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
+ answer
+ ):
+ func_name = m.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ json_str = _extract_balanced_json(answer, m.end())
+ if json_str is None:
+ continue
+ try:
+ arguments = json.loads(json_str)
+ if start_pos is None:
+ # Check for section begin marker before the call marker
+ section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
+ start_pos = section if section != -1 else m.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ except json.JSONDecodeError:
+ pass
+ return matches, start_pos
+
+
+def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
+ """Parse GLM-style tool calls using arg_key/arg_value XML pairs.
+
+ Format:
+ <tool_call>function_name
+ <arg_key>key1</arg_key>
+ <arg_value>value1</arg_value>
+ </tool_call>
+ """
+ matches = []
+ start_pos = None
+ for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
+ tc_content = tc_match.group(1)
+ # First non-tag text is the function name
+ name_match = re.match(r'([^<\s]+)', tc_content.strip())
+ if not name_match:
+ continue
+ func_name = name_match.group(1).strip()
+ if func_name not in tool_names:
+ continue
+ # Extract arg_key/arg_value pairs
+ keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)]
+ vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)]
+ if len(keys) != len(vals):
+ continue
+ arguments = {}
+ for k, v in zip(keys, vals):
+ try:
+ v = json.loads(v)
+ except (json.JSONDecodeError, ValueError):
+ pass # keep as string
+ arguments[k] = v
+ if start_pos is None:
+ start_pos = tc_match.start()
+ matches.append({
+ "type": "function",
+ "function": {
+ "name": func_name,
+ "arguments": arguments
+ }
+ })
+ return matches, start_pos
+
+
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
    """Parse pythonic-style tool calls used by Llama 4 and similar models.

    Format:
        [func_name(param1="value1", param2="value2"), func_name2(...)]

    Returns:
        (matches, start_pos): OpenAI-style tool-call dicts and the index in
        `answer` of the opening bracket (None if nothing matched).
    """
    calls = []
    first_pos = None

    # Locate a bracketed list of function calls (no nested brackets)
    list_match = re.search(r'\[([^\[\]]+)\]', answer)
    if list_match is None:
        return calls, first_pos

    body = list_match.group(1)

    # Only recognize calls to known tool names
    call_re = re.compile(
        r'(' + '|'.join(re.escape(n) for n in tool_names) + r')\(([^)]*)\)'
    )
    # key="value" pairs; quoted alternatives keep commas inside quotes intact
    param_re = re.compile(
        r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)'
    )

    for call in call_re.finditer(body):
        args = {}
        raw_params = call.group(2).strip()
        if raw_params:
            for pm in param_re.finditer(raw_params):
                key = pm.group(1)
                value = pm.group(2).strip()
                # Strip surrounding quotes
                if (value.startswith('"') and value.endswith('"')) or \
                   (value.startswith("'") and value.endswith("'")):
                    value = value[1:-1]
                # Numbers, booleans and null come through as JSON
                try:
                    value = json.loads(value)
                except (json.JSONDecodeError, ValueError):
                    pass
                args[key] = value

        if first_pos is None:
            first_pos = list_match.start()
        calls.append({
            "type": "function",
            "function": {
                "name": call.group(1),
                "arguments": args
            }
        })

    return calls, first_pos
+
+
def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False):
    """Extract tool calls from model output, trying each known format in turn.

    Each format-specific parser is tried in a fixed order; the first one that
    yields matches wins. If none match, fall back to generic tag/fence-wrapped
    JSON, and finally to parsing the whole answer as bare JSON.

    Args:
        answer: raw model output to scan.
        tool_names: names of tools the model may call; calls to unknown names
            are ignored by the format parsers.
        return_prefix: when True, also return the text preceding the first
            detected tool call (empty string if there is none).

    Returns:
        A list of OpenAI-style tool-call dicts, or (list, prefix) when
        return_prefix is True.
    """
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix
        return matches

    # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
    matches, start_pos = _parse_deep_seek_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for Kimi-K2-style tool calls (pipe-delimited tokens)
    matches, start_pos = _parse_kimi_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for channel-based tool calls (e.g. GPT-OSS format)
    matches, start_pos = _parse_channel_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for MiniMax-style tool calls (invoke/parameter XML tags)
    matches, start_pos = _parse_minimax_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for GLM-style tool calls (arg_key/arg_value XML pairs)
    matches, start_pos = _parse_glm_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
    matches, start_pos = _parse_xml_param_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
    matches, start_pos = _parse_mistral_token_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for bare function-name style tool calls (e.g. Mistral format)
    matches, start_pos = _parse_bare_name_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for pythonic-style tool calls (e.g. Llama 4 format)
    matches, start_pos = _parse_pythonic_tool_calls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Generic fallback: JSON wrapped in fenced code blocks or XML-like tags
    # (e.g. <tool_call>...</tool_call>) as observed from various models.
    # The second pattern needs the '</' of the closing tag for the
    # backreference to actually match well-formed tags.
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue
            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"

            candidates = []
            try:
                # parse the candidate JSON into a dictionary
                candidates = json.loads(candidate)
                if not isinstance(candidates, list):
                    candidates = [candidates]
            except json.JSONDecodeError:
                # Ignore invalid JSON silently
                continue

            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()
                    matches.append(checked_candidate)

    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
    if len(matches) == 0:
        try:
            candidate = answer
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"
            # parse the candidate JSON into a dictionary
            candidates = json.loads(candidate)
            if not isinstance(candidates, list):
                candidates = [candidates]
            for candidate_dict in candidates:
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    matches.append(checked_candidate)
        except json.JSONDecodeError:
            # Ignore invalid JSON silently
            pass

    return _return(matches, start_pos)
diff --git a/modules/tool_use.py b/modules/tool_use.py
index 55424853..e22b1798 100644
--- a/modules/tool_use.py
+++ b/modules/tool_use.py
@@ -3,7 +3,7 @@ import json
from modules import shared
from modules.logging_colors import logger
-from modules.utils import natural_keys
+from modules.utils import natural_keys, sanitize_filename
def get_available_tools():
@@ -23,6 +23,10 @@ def load_tools(selected_names):
tool_defs = []
executors = {}
for name in selected_names:
+ name = sanitize_filename(name)
+ if not name:
+ continue
+
path = shared.user_data_dir / 'tools' / f'{name}.py'
if not path.exists():
continue
diff --git a/modules/ui_chat.py b/modules/ui_chat.py
index ce9fc0a2..0acf9c04 100644
--- a/modules/ui_chat.py
+++ b/modules/ui_chat.py
@@ -97,7 +97,7 @@ def create_ui():
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
def sync_web_tools(selected):
- if 'web_search' in selected and 'fetch_webpage' not in selected:
+ if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
selected.append('fetch_webpage')
return gr.update(value=selected)
diff --git a/modules/web_search.py b/modules/web_search.py
index 754dd111..216d7933 100644
--- a/modules/web_search.py
+++ b/modules/web_search.py
@@ -1,11 +1,13 @@
import concurrent.futures
import html
+import ipaddress
import random
import re
+import socket
import urllib.request
from concurrent.futures import as_completed
from datetime import datetime
-from urllib.parse import quote_plus
+from urllib.parse import quote_plus, urlparse
import requests
@@ -13,6 +15,26 @@ from modules import shared
from modules.logging_colors import logger
def _validate_url(url):
    """Validate that a URL is safe to fetch (not targeting private/internal networks).

    Args:
        url: the URL string to validate.

    Raises:
        ValueError: if the scheme is not http(s), the hostname is missing or
            unresolvable, or any resolved address is private/internal.

    NOTE(review): this validates the resolved addresses at check time only; the
    subsequent HTTP request performs its own DNS lookup, so a DNS-rebinding
    (TOCTOU) window remains — confirm whether that residual risk is acceptable.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ('http', 'https'):
        raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")

    hostname = parsed.hostname
    if not hostname:
        raise ValueError("No hostname in URL")

    # Resolve hostname; keep the try minimal so the blocklist ValueError below
    # cannot be confused with a resolution failure
    try:
        addrinfo = socket.getaddrinfo(hostname, None)
    except socket.gaierror:
        raise ValueError(f"Could not resolve hostname: {hostname}")

    # Check all returned addresses; also reject multicast and the unspecified
    # address (0.0.0.0 / ::), which the private/loopback/link-local/reserved
    # checks alone do not cover
    for family, _, _, _, sockaddr in addrinfo:
        ip = ipaddress.ip_address(sockaddr[0])
        if (ip.is_private or ip.is_loopback or ip.is_link_local
                or ip.is_reserved or ip.is_multicast or ip.is_unspecified):
            raise ValueError(f"Access to private/internal address {ip} is blocked")
+
+
def get_current_timestamp():
"""Returns the current time in 24-hour format"""
return datetime.now().strftime('%b %d, %Y %H:%M')
@@ -25,11 +47,20 @@ def download_web_page(url, timeout=10, include_links=False):
import html2text
try:
+ _validate_url(url)
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
- response = requests.get(url, headers=headers, timeout=timeout)
- response.raise_for_status() # Raise an exception for bad status codes
+ max_redirects = 5
+ for _ in range(max_redirects):
+ response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False)
+ if response.is_redirect and 'Location' in response.headers:
+ url = response.headers['Location']
+ _validate_url(url)
+ else:
+ break
+
+ response.raise_for_status()
# Initialize the HTML to Markdown converter
h = html2text.HTML2Text()