mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-17 19:04:39 +01:00
Clean up tool calling code
This commit is contained in:
parent
cb88066d15
commit
e8d1c66303
|
|
@ -11,7 +11,8 @@ from pydantic import ValidationError
|
|||
|
||||
from extensions.openai.errors import InvalidRequestError
|
||||
from extensions.openai.typing import ToolDefinition
|
||||
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules.tool_parsing import get_tool_call_id, parse_tool_call
|
||||
from modules import shared
|
||||
from modules.reasoning import extract_reasoning
|
||||
from modules.chat import (
|
||||
|
|
@ -491,10 +492,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
answer = a['internal'][-1][1]
|
||||
|
||||
if supported_tools is not None:
|
||||
tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
|
||||
tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
|
||||
if len(tool_call) > 0:
|
||||
for tc in tool_call:
|
||||
tc["id"] = getToolCallId()
|
||||
tc["id"] = get_tool_call_id()
|
||||
if stream:
|
||||
tc["index"] = len(tool_calls)
|
||||
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
|
@ -55,558 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
|
|||
time.sleep(3)
|
||||
|
||||
raise Exception('Could not start cloudflared.')
|
||||
|
||||
|
||||
def getToolCallId() -> str:
    """Generate a random OpenAI-style tool call id: 'call_' plus 8 alphanumerics."""
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
    suffix = [random.choice(alphabet) for _ in range(8)]
    return "call_" + "".join(suffix).lower()
|
||||
|
||||
|
||||
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
    """Normalize a candidate tool-call dict into the OpenAI shape.

    Accepts bare {"name": ...} payloads and {"function": "name"} variants,
    wraps them as {"type": "function", "function": {...}}, and maps the
    legacy 'parameters' key to 'arguments'. Returns the sanitized dict, or
    None when the candidate does not name a known tool.
    """
    # Bare payload ({"name": ..., ...}) -> wrap it as the "function" member.
    if 'function' not in candidate_dict and isinstance(candidate_dict.get('name'), str):
        candidate_dict = {"type": "function", "function": candidate_dict}

    # {"function": "tool_name", ...} -> move the string to "name" and wrap.
    if isinstance(candidate_dict.get('function'), str):
        candidate_dict['name'] = candidate_dict.pop('function')
        candidate_dict = {"type": "function", "function": candidate_dict}

    fn = candidate_dict.get('function')
    if isinstance(fn, dict) and fn.get('name') in tool_names:
        # Ensure the required top-level 'type' property is present and correct.
        candidate_dict["type"] = "function"
        # Some older models emit 'parameters' instead of 'arguments'.
        if "arguments" not in fn and "parameters" in fn:
            fn["arguments"] = fn.pop("parameters")
        return candidate_dict

    return None
|
||||
|
||||
|
||||
def _extractBalancedJson(text: str, start: int) -> str | None:
|
||||
"""Extract a balanced JSON object from text starting at the given position.
|
||||
|
||||
Walks through the string tracking brace depth and string boundaries
|
||||
to correctly handle arbitrary nesting levels.
|
||||
"""
|
||||
if start >= len(text) or text[start] != '{':
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if c == '\\' and in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if c == '{':
|
||||
depth += 1
|
||||
elif c == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start:i + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
    """Parse channel-based tool calls used by GPT-OSS and similar models.

    Format:
        <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
    or:
        <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}

    Returns:
        (matches, start_pos): a list of OpenAI-style tool-call dicts and the
        index in `answer` where the first call begins (None if none found).
    """
    matches = []
    start_pos = None
    # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
    # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
    patterns = [
        r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
        r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, answer):
            func_name = m.group(1).strip()
            if func_name not in tool_names:
                # Unknown tool: skip this match, later matches may be valid.
                continue
            # The JSON payload starts directly after the <|message|> marker.
            json_str = _extractBalancedJson(answer, m.end())
            if json_str is None:
                continue
            try:
                arguments = json.loads(json_str)
                if start_pos is None:
                    # Extend the call span back to the <|start|>assistant prefix if present.
                    prefix = answer.rfind('<|start|>assistant', 0, m.start())
                    start_pos = prefix if prefix != -1 else m.start()
                matches.append({
                    "type": "function",
                    "function": {
                        "name": func_name,
                        "arguments": arguments
                    }
                })
            except json.JSONDecodeError:
                # Malformed payload: ignore this candidate silently.
                pass
        if matches:
            # First pattern that yields results wins; don't re-match with the other.
            break
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]):
    """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.

    Format:
        [TOOL_CALLS]func_name[ARGS]{"arg": "value"}

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first [TOOL_CALLS] token (None if none found).
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        # JSON arguments follow immediately after the [ARGS] token.
        json_str = _extractBalancedJson(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            # Invalid payload: skip this candidate silently.
            pass
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
    """Parse bare function-name style tool calls used by Mistral and similar models.

    Format:
        functionName{"arg": "value"}
    Multiple calls are concatenated directly or separated by whitespace.

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first call in `answer` (None if none found).
    """
    matches = []
    start_pos = None
    # Match tool name followed by opening brace, then extract balanced JSON
    escaped_names = [re.escape(name) for name in tool_names]
    pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
    for match in re.finditer(pattern, answer):
        text = match.group(0)
        # Recover which tool name matched (the alternation is non-capturing).
        name = None
        for n in tool_names:
            if text.startswith(n):
                name = n
                break
        if not name:
            continue
        # match.end() is just past the '{'; back up one to include it.
        brace_start = match.end() - 1
        json_str = _extractBalancedJson(answer, brace_start)
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            # Invalid payload: skip this candidate silently.
            pass
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
    """Parse XML-parameter style tool calls used by Qwen3.5 and similar models.

    Format:
        <tool_call>
        <function=function_name>
        <parameter=param_name>value</parameter>
        </function>
        </tool_call>

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first <tool_call> tag (None if none found).
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        func_match = re.search(r'<function=([^>]+)>', tc_content)
        if not func_match:
            continue
        func_name = func_match.group(1).strip()
        if func_name not in tool_names:
            continue
        arguments = {}
        for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
            param_name = param_match.group(1).strip()
            param_value = param_match.group(2).strip()
            # Values are plain text; decode JSON scalars (numbers, bools, null)
            # when possible, otherwise keep the raw string.
            try:
                param_value = json.loads(param_value)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[param_name] = param_value
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
    """Parse Kimi-K2-style tool calls using pipe-delimited tokens.

    Format:
        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
        <|tool_calls_section_end|>

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index where the tool-call section begins (None if none found).
    """
    matches = []
    start_pos = None
    # The optional 'functions.' prefix and ':index' suffix are stripped by the
    # non-capturing groups; only the bare function name is captured.
    for m in re.finditer(
        r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extractBalancedJson(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                # Check for section begin marker before the call marker
                section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start())
                start_pos = section if section != -1 else m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            # Invalid payload: skip this candidate silently.
            pass
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
    """Parse MiniMax-style tool calls using invoke/parameter XML tags.

    Format:
        <minimax:tool_call>
        <invoke name="function_name">
        <parameter name="param_name">value</parameter>
        </invoke>
        </minimax:tool_call>

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first <minimax:tool_call> block (None if none found).
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # Split on <invoke> to handle multiple parallel calls in one block
        for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
            func_name = invoke_match.group(1).strip()
            if func_name not in tool_names:
                continue
            invoke_body = invoke_match.group(2)
            arguments = {}
            for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
                param_name = param_match.group(1).strip()
                param_value = param_match.group(2).strip()
                # Decode JSON scalars when possible; otherwise keep the raw string.
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass  # keep as string
                arguments[param_name] = param_value
            if start_pos is None:
                start_pos = tc_match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
    return matches, start_pos
|
||||
|
||||
|
||||
def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
|
||||
|
||||
Format:
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for m in re.finditer(
|
||||
r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
json_str = _extractBalancedJson(answer, m.end())
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
if start_pos is None:
|
||||
# Check for section begin marker before the call marker
|
||||
section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
|
||||
start_pos = section if section != -1 else m.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _parseGlmToolCalls(answer: str, tool_names: list[str]):
    """Parse GLM-style tool calls using arg_key/arg_value XML pairs.

    Format:
        <tool_call>function_name
        <arg_key>key1</arg_key>
        <arg_value>value1</arg_value>
        </tool_call>

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first <tool_call> tag (None if none found).
    """
    matches = []
    start_pos = None
    for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
        tc_content = tc_match.group(1)
        # First non-tag text is the function name
        name_match = re.match(r'([^<\s]+)', tc_content.strip())
        if not name_match:
            continue
        func_name = name_match.group(1).strip()
        if func_name not in tool_names:
            continue
        # Extract arg_key/arg_value pairs
        keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
        vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
        if len(keys) != len(vals):
            # Mismatched pairs means the block is malformed; skip it entirely.
            continue
        arguments = {}
        for k, v in zip(keys, vals):
            # Decode JSON scalars when possible; otherwise keep the raw string.
            try:
                v = json.loads(v)
            except (json.JSONDecodeError, ValueError):
                pass  # keep as string
            arguments[k] = v
        if start_pos is None:
            start_pos = tc_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
|
||||
|
||||
|
||||
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
    """Parse pythonic-style tool calls used by Llama 4 and similar models.

    Format:
        [func_name(param1="value1", param2="value2"), func_name2(...)]

    Only the first bracketed group in `answer` is considered.

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the opening bracket (None if none found).
    """
    matches = []
    start_pos = None
    # Match a bracketed list of function calls
    bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
    if not bracket_match:
        return matches, start_pos

    inner = bracket_match.group(1)

    # Build pattern for known tool names
    # NOTE(review): an empty tool_names list would produce an empty alternation
    # that matches any parenthesized text — confirm callers never pass [].
    escaped_names = [re.escape(name) for name in tool_names]
    name_pattern = '|'.join(escaped_names)

    for call_match in re.finditer(
        r'(' + name_pattern + r')\(([^)]*)\)',
        inner
    ):
        func_name = call_match.group(1)
        params_str = call_match.group(2).strip()
        arguments = {}

        if params_str:
            # Parse key="value" pairs, handling commas inside quoted values
            for param_match in re.finditer(
                r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
                params_str
            ):
                param_name = param_match.group(1)
                param_value = param_match.group(2).strip()
                # Strip surrounding quotes
                if (param_value.startswith('"') and param_value.endswith('"')) or \
                        (param_value.startswith("'") and param_value.endswith("'")):
                    param_value = param_value[1:-1]
                # Try to parse as JSON for numeric/bool/null values
                try:
                    param_value = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    pass
                arguments[param_name] = param_value

        if start_pos is None:
            start_pos = bracket_match.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })

    return matches, start_pos
|
||||
|
||||
|
||||
def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = False):
    """Detect tool calls in a model answer, trying every known syntax in turn.

    Each model-specific `_parse*ToolCalls` helper is attempted in priority
    order; the first one that yields matches wins. Falls back to generic
    fenced-code-block / XML-tag wrapped JSON, and finally to parsing the
    whole answer as bare JSON.

    Args:
        answer: the raw model output to scan.
        tool_names: names of the tools the model is allowed to call.
        return_prefix: when True, also return the text preceding the first
            detected tool call.

    Returns:
        matches, or (matches, prefix) when return_prefix is True; prefix is
        '' when no call (or no start position) was found.
    """
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        # Normalize the two return shapes in one place.
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix
        return matches

    # abort on very short answers to save computation cycles
    if len(answer) < 10:
        return _return(matches, start_pos)

    # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
    matches, start_pos = _parseDeepSeekToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for Kimi-K2-style tool calls (pipe-delimited tokens)
    matches, start_pos = _parseKimiToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for channel-based tool calls (e.g. GPT-OSS format)
    matches, start_pos = _parseChannelToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for MiniMax-style tool calls (invoke/parameter XML tags)
    matches, start_pos = _parseMiniMaxToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for GLM-style tool calls (arg_key/arg_value XML pairs)
    matches, start_pos = _parseGlmToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
    matches, start_pos = _parseXmlParamToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
    matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for bare function-name style tool calls (e.g. Mistral format)
    matches, start_pos = _parseBareNameToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Check for pythonic-style tool calls (e.g. Llama 4 format)
    matches, start_pos = _parsePythonicToolCalls(answer, tool_names)
    if matches:
        return _return(matches, start_pos)

    # Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]

    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            # print(match.group(2))
            if match.group(2) is None:
                continue
            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
                if not candidate.strip().startswith("["):
                    candidate = "[" + candidate + "]"

            candidates = []
            try:
                # parse the candidate JSON into a dictionary
                candidates = json.loads(candidate)
                if not isinstance(candidates, list):
                    candidates = [candidates]
            except json.JSONDecodeError:
                # Ignore invalid JSON silently
                continue

            for candidate_dict in candidates:
                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()
                    matches.append(checked_candidate)

    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
    if len(matches) == 0:
        try:
            candidate = answer
            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
            if not candidate.strip().startswith("["):
                candidate = "[" + candidate + "]"
            # parse the candidate JSON into a dictionary
            candidates = json.loads(candidate)
            if not isinstance(candidates, list):
                candidates = [candidates]
            for candidate_dict in candidates:
                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    # NOTE: start_pos stays None here, so return_prefix yields ''
                    matches.append(checked_candidate)
        except json.JSONDecodeError:
            # Ignore invalid JSON silently
            pass

    return _return(matches, start_pos)
|
||||
|
|
|
|||
|
|
@ -239,6 +239,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
name1=state['name1'],
|
||||
name2=state['name2'],
|
||||
user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']),
|
||||
tools=state['tools'] if 'tools' in state else None,
|
||||
)
|
||||
|
||||
messages = []
|
||||
|
|
@ -1186,14 +1187,10 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
|
||||
# Load tools if any are selected
|
||||
selected = state.get('selected_tools', [])
|
||||
parseToolCall = None
|
||||
parse_tool_call = None
|
||||
if selected:
|
||||
from modules.tool_use import load_tools, execute_tool
|
||||
try:
|
||||
from extensions.openai.utils import parseToolCall, getToolCallId
|
||||
except ImportError:
|
||||
logger.warning('Tool calling requires the openai extension for parseToolCall. Disabling tools.')
|
||||
selected = []
|
||||
from modules.tool_parsing import parse_tool_call, get_tool_call_id
|
||||
|
||||
if selected:
|
||||
tool_defs, tool_executors = load_tools(selected)
|
||||
|
|
@ -1253,7 +1250,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
last_save_time = current_time
|
||||
|
||||
# Early stop on tool call detection
|
||||
if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names):
|
||||
if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names):
|
||||
break
|
||||
|
||||
# Save the model's visible output before re-applying visible_prefix,
|
||||
|
|
@ -1285,7 +1282,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
break
|
||||
|
||||
answer = history['internal'][-1][1]
|
||||
parsed_calls, content_prefix = parseToolCall(answer, tool_func_names, return_prefix=True) if answer else (None, '')
|
||||
parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '')
|
||||
|
||||
if not parsed_calls:
|
||||
break # No tool calls — done
|
||||
|
|
@ -1302,7 +1299,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
serialized = []
|
||||
tc_headers = []
|
||||
for tc in parsed_calls:
|
||||
tc['id'] = getToolCallId()
|
||||
tc['id'] = get_tool_call_id()
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
|
||||
|
|
@ -1343,7 +1340,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
|
||||
# Preserve thinking block and intermediate text from this turn.
|
||||
# content_prefix is the raw text before tool call syntax (returned
|
||||
# by parseToolCall); HTML-escape it and extract thinking to get
|
||||
# by parse_tool_call); HTML-escape it and extract thinking to get
|
||||
# the content the user should see.
|
||||
content_text = html.escape(content_prefix)
|
||||
thinking_content, intermediate = extract_thinking_block(content_text)
|
||||
|
|
|
|||
553
modules/tool_parsing.py
Normal file
553
modules/tool_parsing.py
Normal file
|
|
@ -0,0 +1,553 @@
|
|||
import json
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
def get_tool_call_id() -> str:
    """Return a random tool-call identifier: 'call_' followed by 8 alphanumerics."""
    pool = "abcdefghijklmnopqrstuvwxyz0123456789"
    return "call_" + "".join(random.choice(pool) for _ in range(8)).lower()
|
||||
|
||||
|
||||
def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]):
    """Coerce a raw candidate dict into the OpenAI tool-call shape.

    Handles three observed variants: a bare {"name": ...} payload, a
    {"function": "name"} string form, and the canonical nested form. Legacy
    'parameters' keys are renamed to 'arguments'. Returns the normalized
    dict, or None when no known tool is named.
    """
    # Variant 1: bare payload with a string "name" — nest it under "function".
    if 'function' not in candidate_dict and isinstance(candidate_dict.get('name'), str):
        candidate_dict = {"type": "function", "function": candidate_dict}

    # Variant 2: "function" holds the tool name as a string — convert to "name".
    if isinstance(candidate_dict.get('function'), str):
        candidate_dict['name'] = candidate_dict.pop('function')
        candidate_dict = {"type": "function", "function": candidate_dict}

    # Canonical form: "function" is a dict naming a known tool.
    function_member = candidate_dict.get('function')
    if isinstance(function_member, dict) and function_member.get('name') in tool_names:
        candidate_dict["type"] = "function"  # required property for the OpenAI schema
        # Older models emit 'parameters' where the schema wants 'arguments'.
        if "arguments" not in function_member and "parameters" in function_member:
            function_member["arguments"] = function_member.pop("parameters")
        return candidate_dict

    return None
|
||||
|
||||
|
||||
def _extract_balanced_json(text: str, start: int) -> str | None:
|
||||
"""Extract a balanced JSON object from text starting at the given position.
|
||||
|
||||
Walks through the string tracking brace depth and string boundaries
|
||||
to correctly handle arbitrary nesting levels.
|
||||
"""
|
||||
if start >= len(text) or text[start] != '{':
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if c == '\\' and in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if c == '{':
|
||||
depth += 1
|
||||
elif c == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start:i + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _parse_channel_tool_calls(answer: str, tool_names: list[str]):
    """Parse channel-based tool calls used by GPT-OSS and similar models.

    Format:
        <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"}
    or:
        <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"}

    Returns:
        (matches, start_pos): a list of OpenAI-style tool-call dicts and the
        index in `answer` where the first call begins (None if none found).
    """
    matches = []
    start_pos = None
    # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
    # Pattern 2: to=functions.NAME after <|channel|> (alternative format)
    patterns = [
        r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
        r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, answer):
            func_name = m.group(1).strip()
            if func_name not in tool_names:
                # Unknown tool: skip this match, later matches may be valid.
                continue
            # The JSON payload starts directly after the <|message|> marker.
            json_str = _extract_balanced_json(answer, m.end())
            if json_str is None:
                continue
            try:
                arguments = json.loads(json_str)
                if start_pos is None:
                    # Extend the call span back to the <|start|>assistant prefix if present.
                    prefix = answer.rfind('<|start|>assistant', 0, m.start())
                    start_pos = prefix if prefix != -1 else m.start()
                matches.append({
                    "type": "function",
                    "function": {
                        "name": func_name,
                        "arguments": arguments
                    }
                })
            except json.JSONDecodeError:
                # Malformed payload: ignore this candidate silently.
                pass
        if matches:
            # First pattern that yields results wins; don't re-match with the other.
            break
    return matches, start_pos
|
||||
|
||||
|
||||
def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]):
    """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.

    Format:
        [TOOL_CALLS]func_name[ARGS]{"arg": "value"}

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first [TOOL_CALLS] token (None if none found).
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        # JSON arguments follow immediately after the [ARGS] token.
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = m.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": func_name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            # Invalid payload: skip this candidate silently.
            pass
    return matches, start_pos
|
||||
|
||||
|
||||
def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]):
    """Parse bare function-name style tool calls used by Mistral and similar models.

    Format:
        functionName{"arg": "value"}
    Multiple calls are concatenated directly or separated by whitespace.

    Returns:
        (matches, start_pos): list of OpenAI-style tool-call dicts and the
        index of the first call in `answer` (None if none found).
    """
    matches = []
    start_pos = None
    # Match tool name followed by opening brace, then extract balanced JSON
    escaped_names = [re.escape(name) for name in tool_names]
    pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
    for match in re.finditer(pattern, answer):
        text = match.group(0)
        # Recover which tool name matched (the alternation is non-capturing).
        name = None
        for n in tool_names:
            if text.startswith(n):
                name = n
                break
        if not name:
            continue
        # match.end() is just past the '{'; back up one to include it.
        brace_start = match.end() - 1
        json_str = _extract_balanced_json(answer, brace_start)
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
            if start_pos is None:
                start_pos = match.start()
            matches.append({
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": arguments
                }
            })
        except json.JSONDecodeError:
            # Invalid payload: skip this candidate silently.
            pass
    return matches, start_pos
|
||||
|
||||
|
||||
def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]):
|
||||
"""Parse XML-parameter style tool calls used by Qwen3.5 and similar models.
|
||||
|
||||
Format:
|
||||
<tool_call>
|
||||
<function=function_name>
|
||||
<parameter=param_name>value</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||
tc_content = tc_match.group(1)
|
||||
func_match = re.search(r'<function=([^>]+)>', tc_content)
|
||||
if not func_match:
|
||||
continue
|
||||
func_name = func_match.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
arguments = {}
|
||||
for param_match in re.finditer(r'<parameter=([^>]+)>\s*(.*?)\s*</parameter>', tc_content, re.DOTALL):
|
||||
param_name = param_match.group(1).strip()
|
||||
param_value = param_match.group(2).strip()
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass # keep as string
|
||||
arguments[param_name] = param_value
|
||||
if start_pos is None:
|
||||
start_pos = tc_match.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _parse_kimi_tool_calls(answer: str, tool_names: list[str]):
    """Parse Kimi-K2-style tool calls using pipe-delimited tokens.

    Format:
        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
        <|tool_calls_section_end|>
    """
    found = []
    first_offset = None
    header_re = r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*'

    for header in re.finditer(header_re, answer):
        name = header.group(1).strip()
        if name not in tool_names:
            continue

        payload = _extract_balanced_json(answer, header.end())
        if payload is None:
            continue

        try:
            args = json.loads(payload)
        except json.JSONDecodeError:
            continue

        if first_offset is None:
            # Prefer the enclosing section marker as the start of the span
            marker = answer.rfind('<|tool_calls_section_begin|>', 0, header.start())
            first_offset = marker if marker != -1 else header.start()

        found.append({
            "type": "function",
            "function": {
                "name": name,
                "arguments": args
            }
        })

    return found, first_offset
||||
def _parse_minimax_tool_calls(answer: str, tool_names: list[str]):
|
||||
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
|
||||
|
||||
Format:
|
||||
<minimax:tool_call>
|
||||
<invoke name="function_name">
|
||||
<parameter name="param_name">value</parameter>
|
||||
</invoke>
|
||||
</minimax:tool_call>
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
|
||||
tc_content = tc_match.group(1)
|
||||
# Split on <invoke> to handle multiple parallel calls in one block
|
||||
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
|
||||
func_name = invoke_match.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
invoke_body = invoke_match.group(2)
|
||||
arguments = {}
|
||||
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
|
||||
param_name = param_match.group(1).strip()
|
||||
param_value = param_match.group(2).strip()
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass # keep as string
|
||||
arguments[param_name] = param_value
|
||||
if start_pos is None:
|
||||
start_pos = tc_match.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]):
    """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.

    Format:
        <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>

    Returns:
        (matches, start_pos): a list of OpenAI-style tool call dicts and the
        index in `answer` where the call section begins (None if none found).
    """
    matches = []
    start_pos = None
    for m in re.finditer(
        # NOTE: the '|' characters must be escaped. Unescaped they are regex
        # alternation, making the pattern match a bare '<' with group(1)=None,
        # which then crashes on .strip(). (The Kimi parser escapes correctly.)
        r'<\|tool▁call▁begin\|>\s*(\S+?)\s*<\|tool▁sep\|>\s*',
        answer
    ):
        func_name = m.group(1).strip()
        if func_name not in tool_names:
            continue
        json_str = _extract_balanced_json(answer, m.end())
        if json_str is None:
            continue
        try:
            arguments = json.loads(json_str)
        except json.JSONDecodeError:
            continue
        if start_pos is None:
            # Check for section begin marker before the call marker
            section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start())
            start_pos = section if section != -1 else m.start()
        matches.append({
            "type": "function",
            "function": {
                "name": func_name,
                "arguments": arguments
            }
        })
    return matches, start_pos
||||
def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
|
||||
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
|
||||
|
||||
Format:
|
||||
<tool_call>function_name
|
||||
<arg_key>key1</arg_key>
|
||||
<arg_value>value1</arg_value>
|
||||
</tool_call>
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||
tc_content = tc_match.group(1)
|
||||
# First non-tag text is the function name
|
||||
name_match = re.match(r'([^<\s]+)', tc_content.strip())
|
||||
if not name_match:
|
||||
continue
|
||||
func_name = name_match.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
# Extract arg_key/arg_value pairs
|
||||
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
|
||||
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
|
||||
if len(keys) != len(vals):
|
||||
continue
|
||||
arguments = {}
|
||||
for k, v in zip(keys, vals):
|
||||
try:
|
||||
v = json.loads(v)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass # keep as string
|
||||
arguments[k] = v
|
||||
if start_pos is None:
|
||||
start_pos = tc_match.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
|
||||
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
||||
|
||||
Format:
|
||||
[func_name(param1="value1", param2="value2"), func_name2(...)]
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
# Match a bracketed list of function calls
|
||||
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
|
||||
if not bracket_match:
|
||||
return matches, start_pos
|
||||
|
||||
inner = bracket_match.group(1)
|
||||
|
||||
# Build pattern for known tool names
|
||||
escaped_names = [re.escape(name) for name in tool_names]
|
||||
name_pattern = '|'.join(escaped_names)
|
||||
|
||||
for call_match in re.finditer(
|
||||
r'(' + name_pattern + r')\(([^)]*)\)',
|
||||
inner
|
||||
):
|
||||
func_name = call_match.group(1)
|
||||
params_str = call_match.group(2).strip()
|
||||
arguments = {}
|
||||
|
||||
if params_str:
|
||||
# Parse key="value" pairs, handling commas inside quoted values
|
||||
for param_match in re.finditer(
|
||||
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
|
||||
params_str
|
||||
):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
# Strip surrounding quotes
|
||||
if (param_value.startswith('"') and param_value.endswith('"')) or \
|
||||
(param_value.startswith("'") and param_value.endswith("'")):
|
||||
param_value = param_value[1:-1]
|
||||
# Try to parse as JSON for numeric/bool/null values
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
arguments[param_name] = param_value
|
||||
|
||||
if start_pos is None:
|
||||
start_pos = bracket_match.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _load_json_candidates(text):
    """Best-effort parse of `text` as a list of JSON objects.

    Handles the common model quirk of emitting several JSON objects separated
    by line breaks (they are joined with commas and wrapped into a JSON array).
    Returns a list of parsed objects, or an empty list when `text` is not
    valid JSON.
    """
    # Multiple objects separated by line breaks -> rewrite into an array body
    if re.search(r"\}\s*\n\s*\{", text) is not None:
        text = re.sub(r"\}\s*\n\s*\{", "},\n{", text)
    if not text.strip().startswith("["):
        text = "[" + text + "]"
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Ignore invalid JSON silently
        return []
    return parsed if isinstance(parsed, list) else [parsed]


def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False):
    """Extract tool calls from a model answer, trying all known formats.

    Args:
        answer: raw model output to scan.
        tool_names: names of the tools the model may call; calls to unknown
            names are ignored.
        return_prefix: when True, also return the text that precedes the
            first tool call.

    Returns:
        A list of OpenAI-style tool call dicts, or `(matches, prefix)` when
        `return_prefix` is True.
    """
    matches = []
    start_pos = None

    def _return(matches, start_pos):
        if return_prefix:
            prefix = answer[:start_pos] if matches and start_pos is not None else ''
            return matches, prefix
        return matches

    # Model-specific structured formats, tried in order. The first parser
    # that yields any match wins.
    structured_parsers = (
        _parse_deep_seek_tool_calls,      # DeepSeek: fullwidth Unicode token delimiters
        _parse_kimi_tool_calls,           # Kimi-K2: pipe-delimited tokens
        _parse_channel_tool_calls,        # GPT-OSS: channel-based
        _parse_minimax_tool_calls,        # MiniMax: invoke/parameter XML tags
        _parse_glm_tool_calls,            # GLM: arg_key/arg_value XML pairs
        _parse_xml_param_tool_calls,      # Qwen3.5: XML parameters
        _parse_mistral_token_tool_calls,  # Mistral/Devstral: [TOOL_CALLS]name[ARGS]json
        _parse_bare_name_tool_calls,      # Mistral: bare function name + JSON
        _parse_pythonic_tool_calls,       # Llama 4: pythonic call list
    )
    for parser in structured_parsers:
        matches, start_pos = parser(answer, tool_names)
        if matches:
            return _return(matches, start_pos)

    # Generic fallback: JSON wrapped in code fences or XML-like tags observed
    # from various models (<function>, <tools>, <tool_call>, ...)
    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
    for pattern in patterns:
        for match in re.finditer(pattern, answer, re.DOTALL):
            if match.group(2) is None:
                continue
            # remove backtick wraps if present
            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
            candidate = re.sub(r"```$", "", candidate.strip())
            # unwrap inner tags
            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
            for candidate_dict in _load_json_candidates(candidate):
                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
                if checked_candidate is not None:
                    if start_pos is None:
                        start_pos = match.start()
                    matches.append(checked_candidate)

    # Last resort: the answer itself may be one or more plain JSON tool calls
    # without any wrapping tags
    if len(matches) == 0:
        for candidate_dict in _load_json_candidates(answer):
            checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
            if checked_candidate is not None:
                matches.append(checked_candidate)

    return _return(matches, start_pos)
|
||||
|
|
@ -3,7 +3,7 @@ import json
|
|||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.utils import natural_keys
|
||||
from modules.utils import natural_keys, sanitize_filename
|
||||
|
||||
|
||||
def get_available_tools():
|
||||
|
|
@ -23,6 +23,10 @@ def load_tools(selected_names):
|
|||
tool_defs = []
|
||||
executors = {}
|
||||
for name in selected_names:
|
||||
name = sanitize_filename(name)
|
||||
if not name:
|
||||
continue
|
||||
|
||||
path = shared.user_data_dir / 'tools' / f'{name}.py'
|
||||
if not path.exists():
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ def create_ui():
|
|||
shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']])
|
||||
|
||||
def sync_web_tools(selected):
|
||||
if 'web_search' in selected and 'fetch_webpage' not in selected:
|
||||
if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools():
|
||||
selected.append('fetch_webpage')
|
||||
|
||||
return gr.update(value=selected)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import concurrent.futures
|
||||
import html
|
||||
import ipaddress
|
||||
import random
|
||||
import re
|
||||
import socket
|
||||
import urllib.request
|
||||
from concurrent.futures import as_completed
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote_plus
|
||||
from urllib.parse import quote_plus, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -13,6 +15,26 @@ from modules import shared
|
|||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def _validate_url(url):
|
||||
"""Validate that a URL is safe to fetch (not targeting private/internal networks)."""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ('http', 'https'):
|
||||
raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
|
||||
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError("No hostname in URL")
|
||||
|
||||
# Resolve hostname and check all returned addresses
|
||||
try:
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
raise ValueError(f"Access to private/internal address {ip} is blocked")
|
||||
except socket.gaierror:
|
||||
raise ValueError(f"Could not resolve hostname: {hostname}")
|
||||
|
||||
|
||||
def get_current_timestamp():
    """Returns the current time in 24-hour format"""
    now = datetime.now()
    return now.strftime('%b %d, %Y %H:%M')
|
||||
|
|
@ -25,11 +47,20 @@ def download_web_page(url, timeout=10, include_links=False):
|
|||
import html2text
|
||||
|
||||
try:
|
||||
_validate_url(url)
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||
}
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
max_redirects = 5
|
||||
for _ in range(max_redirects):
|
||||
response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False)
|
||||
if response.is_redirect and 'Location' in response.headers:
|
||||
url = response.headers['Location']
|
||||
_validate_url(url)
|
||||
else:
|
||||
break
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# Initialize the HTML to Markdown converter
|
||||
h = html2text.HTML2Text()
|
||||
|
|
|
|||
Loading…
Reference in a new issue