diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 290a5bc0..27defe42 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -11,7 +11,8 @@ from pydantic import ValidationError from extensions.openai.errors import InvalidRequestError from extensions.openai.typing import ToolDefinition -from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall +from extensions.openai.utils import debug_msg +from modules.tool_parsing import get_tool_call_id, parse_tool_call from modules import shared from modules.reasoning import extract_reasoning from modules.chat import ( @@ -491,10 +492,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p answer = a['internal'][-1][1] if supported_tools is not None: - tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else [] + tool_call = parse_tool_call(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else [] if len(tool_call) > 0: for tc in tool_call: - tc["id"] = getToolCallId() + tc["id"] = get_tool_call_id() if stream: tc["index"] = len(tool_calls) tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"]) diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index b179c267..2b414769 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -1,8 +1,5 @@ import base64 -import json import os -import random -import re import time import traceback from typing import Callable, Optional @@ -55,558 +52,3 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star time.sleep(3) raise Exception('Could not start cloudflared.') - - -def getToolCallId() -> str: - letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" - b = [random.choice(letter_bytes) for _ in range(8)] - return "call_" + "".join(b).lower() - - -def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]): - # check if 
property 'function' exists and is a dictionary, otherwise adapt dict - if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str): - candidate_dict = {"type": "function", "function": candidate_dict} - if 'function' in candidate_dict and isinstance(candidate_dict['function'], str): - candidate_dict['name'] = candidate_dict['function'] - del candidate_dict['function'] - candidate_dict = {"type": "function", "function": candidate_dict} - if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict): - # check if 'name' exists within 'function' and is part of known tools - if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names: - candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value - # map property 'parameters' used by some older models to 'arguments' - if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]: - candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] - del candidate_dict["function"]["parameters"] - return candidate_dict - return None - - -def _extractBalancedJson(text: str, start: int) -> str | None: - """Extract a balanced JSON object from text starting at the given position. - - Walks through the string tracking brace depth and string boundaries - to correctly handle arbitrary nesting levels. 
- """ - if start >= len(text) or text[start] != '{': - return None - depth = 0 - in_string = False - escape_next = False - for i in range(start, len(text)): - c = text[i] - if escape_next: - escape_next = False - continue - if c == '\\' and in_string: - escape_next = True - continue - if c == '"': - in_string = not in_string - continue - if in_string: - continue - if c == '{': - depth += 1 - elif c == '}': - depth -= 1 - if depth == 0: - return text[start:i + 1] - return None - - -def _parseChannelToolCalls(answer: str, tool_names: list[str]): - """Parse channel-based tool calls used by GPT-OSS and similar models. - - Format: - <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"} - or: - <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"} - """ - matches = [] - start_pos = None - # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format) - # Pattern 2: to=functions.NAME after <|channel|> (alternative format) - patterns = [ - r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>', - r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>', - ] - for pattern in patterns: - for m in re.finditer(pattern, answer): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - if start_pos is None: - prefix = answer.rfind('<|start|>assistant', 0, m.start()) - start_pos = prefix if prefix != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - if matches: - break - return matches, start_pos - - -def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]): - """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens. 
- - Format: - [TOOL_CALLS]func_name[ARGS]{"arg": "value"} - """ - matches = [] - start_pos = None - for m in re.finditer( - r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - if start_pos is None: - start_pos = m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches, start_pos - - -def _parseBareNameToolCalls(answer: str, tool_names: list[str]): - """Parse bare function-name style tool calls used by Mistral and similar models. - - Format: - functionName{"arg": "value"} - Multiple calls are concatenated directly or separated by whitespace. - """ - matches = [] - start_pos = None - # Match tool name followed by opening brace, then extract balanced JSON - escaped_names = [re.escape(name) for name in tool_names] - pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' - for match in re.finditer(pattern, answer): - text = match.group(0) - name = None - for n in tool_names: - if text.startswith(n): - name = n - break - if not name: - continue - brace_start = match.end() - 1 - json_str = _extractBalancedJson(answer, brace_start) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - if start_pos is None: - start_pos = match.start() - matches.append({ - "type": "function", - "function": { - "name": name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches, start_pos - - -def _parseXmlParamToolCalls(answer: str, tool_names: list[str]): - """Parse XML-parameter style tool calls used by Qwen3.5 and similar models. 
- - Format: - - - value - - - """ - matches = [] - start_pos = None - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - func_match = re.search(r']+)>', tc_content) - if not func_match: - continue - func_name = func_match.group(1).strip() - if func_name not in tool_names: - continue - arguments = {} - for param_match in re.finditer(r']+)>\s*(.*?)\s*', tc_content, re.DOTALL): - param_name = param_match.group(1).strip() - param_value = param_match.group(2).strip() - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[param_name] = param_value - if start_pos is None: - start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches, start_pos - - -def _parseKimiToolCalls(answer: str, tool_names: list[str]): - """Parse Kimi-K2-style tool calls using pipe-delimited tokens. - - Format: - <|tool_calls_section_begin|> - <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> - <|tool_calls_section_end|> - """ - matches = [] - start_pos = None - for m in re.finditer( - r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - if start_pos is None: - # Check for section begin marker before the call marker - section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start()) - start_pos = section if section != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches, start_pos - - -def _parseMiniMaxToolCalls(answer: str, tool_names: 
list[str]): - """Parse MiniMax-style tool calls using invoke/parameter XML tags. - - Format: - - - value - - - """ - matches = [] - start_pos = None - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - # Split on to handle multiple parallel calls in one block - for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): - func_name = invoke_match.group(1).strip() - if func_name not in tool_names: - continue - invoke_body = invoke_match.group(2) - arguments = {} - for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): - param_name = param_match.group(1).strip() - param_value = param_match.group(2).strip() - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[param_name] = param_value - if start_pos is None: - start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches, start_pos - - -def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]): - """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. 
- - Format: - <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|> - """ - matches = [] - start_pos = None - for m in re.finditer( - r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*', - answer - ): - func_name = m.group(1).strip() - if func_name not in tool_names: - continue - json_str = _extractBalancedJson(answer, m.end()) - if json_str is None: - continue - try: - arguments = json.loads(json_str) - if start_pos is None: - # Check for section begin marker before the call marker - section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start()) - start_pos = section if section != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass - return matches, start_pos - - -def _parseGlmToolCalls(answer: str, tool_names: list[str]): - """Parse GLM-style tool calls using arg_key/arg_value XML pairs. - - Format: - function_name - key1 - value1 - - """ - matches = [] - start_pos = None - for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): - tc_content = tc_match.group(1) - # First non-tag text is the function name - name_match = re.match(r'([^<\s]+)', tc_content.strip()) - if not name_match: - continue - func_name = name_match.group(1).strip() - if func_name not in tool_names: - continue - # Extract arg_key/arg_value pairs - keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] - vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] - if len(keys) != len(vals): - continue - arguments = {} - for k, v in zip(keys, vals): - try: - v = json.loads(v) - except (json.JSONDecodeError, ValueError): - pass # keep as string - arguments[k] = v - if start_pos is None: - start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - return matches, start_pos - - -def 
_parsePythonicToolCalls(answer: str, tool_names: list[str]): - """Parse pythonic-style tool calls used by Llama 4 and similar models. - - Format: - [func_name(param1="value1", param2="value2"), func_name2(...)] - """ - matches = [] - start_pos = None - # Match a bracketed list of function calls - bracket_match = re.search(r'\[([^\[\]]+)\]', answer) - if not bracket_match: - return matches, start_pos - - inner = bracket_match.group(1) - - # Build pattern for known tool names - escaped_names = [re.escape(name) for name in tool_names] - name_pattern = '|'.join(escaped_names) - - for call_match in re.finditer( - r'(' + name_pattern + r')\(([^)]*)\)', - inner - ): - func_name = call_match.group(1) - params_str = call_match.group(2).strip() - arguments = {} - - if params_str: - # Parse key="value" pairs, handling commas inside quoted values - for param_match in re.finditer( - r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', - params_str - ): - param_name = param_match.group(1) - param_value = param_match.group(2).strip() - # Strip surrounding quotes - if (param_value.startswith('"') and param_value.endswith('"')) or \ - (param_value.startswith("'") and param_value.endswith("'")): - param_value = param_value[1:-1] - # Try to parse as JSON for numeric/bool/null values - try: - param_value = json.loads(param_value) - except (json.JSONDecodeError, ValueError): - pass - arguments[param_name] = param_value - - if start_pos is None: - start_pos = bracket_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) - - return matches, start_pos - - -def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = False): - matches = [] - start_pos = None - - def _return(matches, start_pos): - if return_prefix: - prefix = answer[:start_pos] if matches and start_pos is not None else '' - return matches, prefix - return matches - - # abort on very short answers to save computation cycles - 
if len(answer) < 10: - return _return(matches, start_pos) - - # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters) - matches, start_pos = _parseDeepSeekToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for Kimi-K2-style tool calls (pipe-delimited tokens) - matches, start_pos = _parseKimiToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for channel-based tool calls (e.g. GPT-OSS format) - matches, start_pos = _parseChannelToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for MiniMax-style tool calls (invoke/parameter XML tags) - matches, start_pos = _parseMiniMaxToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for GLM-style tool calls (arg_key/arg_value XML pairs) - matches, start_pos = _parseGlmToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for XML-parameter style tool calls (e.g. Qwen3.5 format) - matches, start_pos = _parseXmlParamToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json) - matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for bare function-name style tool calls (e.g. Mistral format) - matches, start_pos = _parseBareNameToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Check for pythonic-style tool calls (e.g. 
Llama 4 format) - matches, start_pos = _parsePythonicToolCalls(answer, tool_names) - if matches: - return _return(matches, start_pos) - - # Define the regex pattern to find the JSON content wrapped in , , , and other tags observed from various models - patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"] - - for pattern in patterns: - for match in re.finditer(pattern, answer, re.DOTALL): - # print(match.group(2)) - if match.group(2) is None: - continue - # remove backtick wraps if present - candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip()) - candidate = re.sub(r"```$", "", candidate.strip()) - # unwrap inner tags - candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL) - # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually - if re.search(r"\}\s*\n\s*\{", candidate) is not None: - candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) - if not candidate.strip().startswith("["): - candidate = "[" + candidate + "]" - - candidates = [] - try: - # parse the candidate JSON into a dictionary - candidates = json.loads(candidate) - if not isinstance(candidates, list): - candidates = [candidates] - except json.JSONDecodeError: - # Ignore invalid JSON silently - continue - - for candidate_dict in candidates: - checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) - if checked_candidate is not None: - if start_pos is None: - start_pos = match.start() - matches.append(checked_candidate) - - # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags - if len(matches) == 0: - try: - candidate = answer - # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually - if re.search(r"\}\s*\n\s*\{", candidate) is not None: - candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) - if not 
candidate.strip().startswith("["): - candidate = "[" + candidate + "]" - # parse the candidate JSON into a dictionary - candidates = json.loads(candidate) - if not isinstance(candidates, list): - candidates = [candidates] - for candidate_dict in candidates: - checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) - if checked_candidate is not None: - matches.append(checked_candidate) - except json.JSONDecodeError: - # Ignore invalid JSON silently - pass - - return _return(matches, start_pos) diff --git a/modules/chat.py b/modules/chat.py index 87e52851..02ae46e4 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -239,6 +239,7 @@ def generate_chat_prompt(user_input, state, **kwargs): name1=state['name1'], name2=state['name2'], user_bio=replace_character_names(state['user_bio'], state['name1'], state['name2']), + tools=state['tools'] if 'tools' in state else None, ) messages = [] @@ -1186,14 +1187,10 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): # Load tools if any are selected selected = state.get('selected_tools', []) - parseToolCall = None + parse_tool_call = None if selected: from modules.tool_use import load_tools, execute_tool - try: - from extensions.openai.utils import parseToolCall, getToolCallId - except ImportError: - logger.warning('Tool calling requires the openai extension for parseToolCall. 
Disabling tools.') - selected = [] + from modules.tool_parsing import parse_tool_call, get_tool_call_id if selected: tool_defs, tool_executors = load_tools(selected) @@ -1253,7 +1250,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): last_save_time = current_time # Early stop on tool call detection - if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names): + if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names): break # Save the model's visible output before re-applying visible_prefix, @@ -1285,7 +1282,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): break answer = history['internal'][-1][1] - parsed_calls, content_prefix = parseToolCall(answer, tool_func_names, return_prefix=True) if answer else (None, '') + parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True) if answer else (None, '') if not parsed_calls: break # No tool calls — done @@ -1302,7 +1299,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): serialized = [] tc_headers = [] for tc in parsed_calls: - tc['id'] = getToolCallId() + tc['id'] = get_tool_call_id() fn_name = tc['function']['name'] fn_args = tc['function'].get('arguments', {}) @@ -1343,7 +1340,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): # Preserve thinking block and intermediate text from this turn. # content_prefix is the raw text before tool call syntax (returned - # by parseToolCall); HTML-escape it and extract thinking to get + # by parse_tool_call); HTML-escape it and extract thinking to get # the content the user should see. 
content_text = html.escape(content_prefix) thinking_content, intermediate = extract_thinking_block(content_text) diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py new file mode 100644 index 00000000..460188d3 --- /dev/null +++ b/modules/tool_parsing.py @@ -0,0 +1,553 @@ +import json +import random +import re + + +def get_tool_call_id() -> str: + letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" + b = [random.choice(letter_bytes) for _ in range(8)] + return "call_" + "".join(b).lower() + + +def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]): + # check if property 'function' exists and is a dictionary, otherwise adapt dict + if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str): + candidate_dict = {"type": "function", "function": candidate_dict} + if 'function' in candidate_dict and isinstance(candidate_dict['function'], str): + candidate_dict['name'] = candidate_dict['function'] + del candidate_dict['function'] + candidate_dict = {"type": "function", "function": candidate_dict} + if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict): + # check if 'name' exists within 'function' and is part of known tools + if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names: + candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value + # map property 'parameters' used by some older models to 'arguments' + if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]: + candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] + del candidate_dict["function"]["parameters"] + return candidate_dict + return None + + +def _extract_balanced_json(text: str, start: int) -> str | None: + """Extract a balanced JSON object from text starting at the given position. 
+ + Walks through the string tracking brace depth and string boundaries + to correctly handle arbitrary nesting levels. + """ + if start >= len(text) or text[start] != '{': + return None + depth = 0 + in_string = False + escape_next = False + for i in range(start, len(text)): + c = text[i] + if escape_next: + escape_next = False + continue + if c == '\\' and in_string: + escape_next = True + continue + if c == '"': + in_string = not in_string + continue + if in_string: + continue + if c == '{': + depth += 1 + elif c == '}': + depth -= 1 + if depth == 0: + return text[start:i + 1] + return None + + +def _parse_channel_tool_calls(answer: str, tool_names: list[str]): + """Parse channel-based tool calls used by GPT-OSS and similar models. + + Format: + <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"} + or: + <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"} + """ + matches = [] + start_pos = None + # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format) + # Pattern 2: to=functions.NAME after <|channel|> (alternative format) + patterns = [ + r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>', + r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>', + ] + for pattern in patterns: + for m in re.finditer(pattern, answer): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + prefix = answer.rfind('<|start|>assistant', 0, m.start()) + start_pos = prefix if prefix != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + if matches: + break + return matches, start_pos + + +def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]): + """Parse 
Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens. + + Format: + [TOOL_CALLS]func_name[ARGS]{"arg": "value"} + """ + matches = [] + start_pos = None + for m in re.finditer( + r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + start_pos = m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]): + """Parse bare function-name style tool calls used by Mistral and similar models. + + Format: + functionName{"arg": "value"} + Multiple calls are concatenated directly or separated by whitespace. + """ + matches = [] + start_pos = None + # Match tool name followed by opening brace, then extract balanced JSON + escaped_names = [re.escape(name) for name in tool_names] + pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' + for match in re.finditer(pattern, answer): + text = match.group(0) + name = None + for n in tool_names: + if text.startswith(n): + name = n + break + if not name: + continue + brace_start = match.end() - 1 + json_str = _extract_balanced_json(answer, brace_start) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + start_pos = match.start() + matches.append({ + "type": "function", + "function": { + "name": name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]): + """Parse XML-parameter style tool calls used by Qwen3.5 and similar models. 
+ + Format: + + + value + + + """ + matches = [] + start_pos = None + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + func_match = re.search(r']+)>', tc_content) + if not func_match: + continue + func_name = func_match.group(1).strip() + if func_name not in tool_names: + continue + arguments = {} + for param_match in re.finditer(r']+)>\s*(.*?)\s*', tc_content, re.DOTALL): + param_name = param_match.group(1).strip() + param_value = param_match.group(2).strip() + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[param_name] = param_value + if start_pos is None: + start_pos = tc_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches, start_pos + + +def _parse_kimi_tool_calls(answer: str, tool_names: list[str]): + """Parse Kimi-K2-style tool calls using pipe-delimited tokens. + + Format: + <|tool_calls_section_begin|> + <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> + <|tool_calls_section_end|> + """ + matches = [] + start_pos = None + for m in re.finditer( + r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + # Check for section begin marker before the call marker + section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start()) + start_pos = section if section != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_minimax_tool_calls(answer: str, 
tool_names: list[str]): + """Parse MiniMax-style tool calls using invoke/parameter XML tags. + + Format: + + + value + + + """ + matches = [] + start_pos = None + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + # Split on to handle multiple parallel calls in one block + for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): + func_name = invoke_match.group(1).strip() + if func_name not in tool_names: + continue + invoke_body = invoke_match.group(2) + arguments = {} + for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): + param_name = param_match.group(1).strip() + param_value = param_match.group(2).strip() + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[param_name] = param_value + if start_pos is None: + start_pos = tc_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches, start_pos + + +def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]): + """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. 
+ + Format: + <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|> + """ + matches = [] + start_pos = None + for m in re.finditer( + r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extract_balanced_json(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + # Check for section begin marker before the call marker + section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start()) + start_pos = section if section != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches, start_pos + + +def _parse_glm_tool_calls(answer: str, tool_names: list[str]): + """Parse GLM-style tool calls using arg_key/arg_value XML pairs. + + Format: + function_name + key1 + value1 + + """ + matches = [] + start_pos = None + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + # First non-tag text is the function name + name_match = re.match(r'([^<\s]+)', tc_content.strip()) + if not name_match: + continue + func_name = name_match.group(1).strip() + if func_name not in tool_names: + continue + # Extract arg_key/arg_value pairs + keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] + vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] + if len(keys) != len(vals): + continue + arguments = {} + for k, v in zip(keys, vals): + try: + v = json.loads(v) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[k] = v + if start_pos is None: + start_pos = tc_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches, start_pos + + 
+def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): + """Parse pythonic-style tool calls used by Llama 4 and similar models. + + Format: + [func_name(param1="value1", param2="value2"), func_name2(...)] + """ + matches = [] + start_pos = None + # Match a bracketed list of function calls + bracket_match = re.search(r'\[([^\[\]]+)\]', answer) + if not bracket_match: + return matches, start_pos + + inner = bracket_match.group(1) + + # Build pattern for known tool names + escaped_names = [re.escape(name) for name in tool_names] + name_pattern = '|'.join(escaped_names) + + for call_match in re.finditer( + r'(' + name_pattern + r')\(([^)]*)\)', + inner + ): + func_name = call_match.group(1) + params_str = call_match.group(2).strip() + arguments = {} + + if params_str: + # Parse key="value" pairs, handling commas inside quoted values + for param_match in re.finditer( + r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', + params_str + ): + param_name = param_match.group(1) + param_value = param_match.group(2).strip() + # Strip surrounding quotes + if (param_value.startswith('"') and param_value.endswith('"')) or \ + (param_value.startswith("'") and param_value.endswith("'")): + param_value = param_value[1:-1] + # Try to parse as JSON for numeric/bool/null values + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass + arguments[param_name] = param_value + + if start_pos is None: + start_pos = bracket_match.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + + return matches, start_pos + + +def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False): + matches = [] + start_pos = None + + def _return(matches, start_pos): + if return_prefix: + prefix = answer[:start_pos] if matches and start_pos is not None else '' + return matches, prefix + return matches + + # Check for DeepSeek-style tool calls (fullwidth 
Unicode token delimiters) + matches, start_pos = _parse_deep_seek_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for Kimi-K2-style tool calls (pipe-delimited tokens) + matches, start_pos = _parse_kimi_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for channel-based tool calls (e.g. GPT-OSS format) + matches, start_pos = _parse_channel_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for MiniMax-style tool calls (invoke/parameter XML tags) + matches, start_pos = _parse_minimax_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for GLM-style tool calls (arg_key/arg_value XML pairs) + matches, start_pos = _parse_glm_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for XML-parameter style tool calls (e.g. Qwen3.5 format) + matches, start_pos = _parse_xml_param_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json) + matches, start_pos = _parse_mistral_token_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for bare function-name style tool calls (e.g. Mistral format) + matches, start_pos = _parse_bare_name_tool_calls(answer, tool_names) + if matches: + return _return(matches, start_pos) + + # Check for pythonic-style tool calls (e.g. 
Llama 4 format)
+    matches, start_pos = _parse_pythonic_tool_calls(answer, tool_names)
+    if matches:
+        return _return(matches, start_pos)
+
+    # Define the regex pattern to find the JSON content wrapped in XML-like tags such as <tool_call>, <function_call>, <tools>, and other tags observed from various models
+    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
+
+    for pattern in patterns:
+        for match in re.finditer(pattern, answer, re.DOTALL):
+            if match.group(2) is None:
+                continue
+            # remove backtick wraps if present
+            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
+            candidate = re.sub(r"```$", "", candidate.strip())
+            # unwrap inner tags
+            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
+            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not candidate.strip().startswith("["):
+                candidate = "[" + candidate + "]"
+
+            candidates = []
+            try:
+                # parse the candidate JSON into a dictionary
+                candidates = json.loads(candidate)
+                if not isinstance(candidates, list):
+                    candidates = [candidates]
+            except json.JSONDecodeError:
+                # Ignore invalid JSON silently
+                continue
+
+            for candidate_dict in candidates:
+                checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names)
+                if checked_candidate is not None:
+                    if start_pos is None:
+                        start_pos = match.start()
+                    matches.append(checked_candidate)
+
+    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
+    if len(matches) == 0:
+        try:
+            candidate = answer
+            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not 
candidate.strip().startswith("["): + candidate = "[" + candidate + "]" + # parse the candidate JSON into a dictionary + candidates = json.loads(candidate) + if not isinstance(candidates, list): + candidates = [candidates] + for candidate_dict in candidates: + checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names) + if checked_candidate is not None: + matches.append(checked_candidate) + except json.JSONDecodeError: + # Ignore invalid JSON silently + pass + + return _return(matches, start_pos) diff --git a/modules/tool_use.py b/modules/tool_use.py index 55424853..e22b1798 100644 --- a/modules/tool_use.py +++ b/modules/tool_use.py @@ -3,7 +3,7 @@ import json from modules import shared from modules.logging_colors import logger -from modules.utils import natural_keys +from modules.utils import natural_keys, sanitize_filename def get_available_tools(): @@ -23,6 +23,10 @@ def load_tools(selected_names): tool_defs = [] executors = {} for name in selected_names: + name = sanitize_filename(name) + if not name: + continue + path = shared.user_data_dir / 'tools' / f'{name}.py' if not path.exists(): continue diff --git a/modules/ui_chat.py b/modules/ui_chat.py index ce9fc0a2..0acf9c04 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -97,7 +97,7 @@ def create_ui(): shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']]) def sync_web_tools(selected): - if 'web_search' in selected and 'fetch_webpage' not in selected: + if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools(): selected.append('fetch_webpage') return gr.update(value=selected) diff --git a/modules/web_search.py b/modules/web_search.py index 754dd111..216d7933 100644 --- a/modules/web_search.py +++ b/modules/web_search.py @@ -1,11 +1,13 @@ import concurrent.futures import html +import ipaddress import random import re +import socket 
import urllib.request from concurrent.futures import as_completed from datetime import datetime -from urllib.parse import quote_plus +from urllib.parse import quote_plus, urlparse import requests @@ -13,6 +15,26 @@ from modules import shared from modules.logging_colors import logger +def _validate_url(url): + """Validate that a URL is safe to fetch (not targeting private/internal networks).""" + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https'): + raise ValueError(f"Unsupported URL scheme: {parsed.scheme}") + + hostname = parsed.hostname + if not hostname: + raise ValueError("No hostname in URL") + + # Resolve hostname and check all returned addresses + try: + for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None): + ip = ipaddress.ip_address(sockaddr[0]) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + raise ValueError(f"Access to private/internal address {ip} is blocked") + except socket.gaierror: + raise ValueError(f"Could not resolve hostname: {hostname}") + + def get_current_timestamp(): """Returns the current time in 24-hour format""" return datetime.now().strftime('%b %d, %Y %H:%M') @@ -25,11 +47,20 @@ def download_web_page(url, timeout=10, include_links=False): import html2text try: + _validate_url(url) headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' } - response = requests.get(url, headers=headers, timeout=timeout) - response.raise_for_status() # Raise an exception for bad status codes + max_redirects = 5 + for _ in range(max_redirects): + response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False) + if response.is_redirect and 'Location' in response.headers: + url = response.headers['Location'] + _validate_url(url) + else: + break + + response.raise_for_status() # Initialize the HTML to Markdown converter h = html2text.HTML2Text()