diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index 64bd1631..f4a31d1a 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -83,6 +83,39 @@ def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str return None +def _extractBalancedJson(text: str, start: int) -> str | None: + """Extract a balanced JSON object from text starting at the given position. + + Walks through the string tracking brace depth and string boundaries + to correctly handle arbitrary nesting levels. + """ + if start >= len(text) or text[start] != '{': + return None + depth = 0 + in_string = False + escape_next = False + for i in range(start, len(text)): + c = text[i] + if escape_next: + escape_next = False + continue + if c == '\\' and in_string: + escape_next = True + continue + if c == '"': + in_string = not in_string + continue + if in_string: + continue + if c == '{': + depth += 1 + elif c == '}': + depth -= 1 + if depth == 0: + return text[start:i + 1] + return None + + def _parseChannelToolCalls(answer: str, tool_names: list[str]): """Parse channel-based tool calls used by GPT-OSS and similar models. @@ -91,14 +124,17 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]): """ matches = [] for m in re.finditer( - r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})', + r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>', answer ): func_name = m.group(1).strip() if func_name not in tool_names: continue + json_str = _extractBalancedJson(answer, m.end()) + if json_str is None: + continue try: - arguments = json.loads(m.group(2)) + arguments = json.loads(json_str) matches.append({ "type": "function", "function": { @@ -119,27 +155,33 @@ def _parseBareNameToolCalls(answer: str, tool_names: list[str]): Multiple calls are concatenated directly or separated by whitespace. """ matches = [] - # Build pattern that matches any known tool name followed by a JSON object + # Match tool name followed by opening brace, then extract balanced JSON escaped_names = [re.escape(name) for name in tool_names] - pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' + pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' for match in re.finditer(pattern, answer): text = match.group(0) - # Split into function name and JSON arguments - for name in tool_names: - if text.startswith(name): - json_str = text[len(name):].strip() - try: - arguments = json.loads(json_str) - matches.append({ - "type": "function", - "function": { - "name": name, - "arguments": arguments - } - }) - except json.JSONDecodeError: - pass + name = None + for n in tool_names: + if text.startswith(n): + name = n break + if not name: + continue + brace_start = match.end() - 1 + json_str = _extractBalancedJson(answer, brace_start) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + matches.append({ + "type": "function", + "function": { + "name": name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass return matches @@ -181,6 +223,149 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]): return matches +def _parseKimiToolCalls(answer: str, tool_names: list[str]): + """Parse Kimi-K2-style tool calls using pipe-delimited tokens. + + Format: + <|tool_calls_section_begin|> + <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> + <|tool_calls_section_end|> + """ + matches = [] + for m in re.finditer( + r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extractBalancedJson(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches + + +def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]): + """Parse MiniMax-style tool calls using invoke/parameter XML tags. + + Format: + + + value + + + """ + matches = [] + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + # Split on to handle multiple parallel calls in one block + for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): + func_name = invoke_match.group(1).strip() + if func_name not in tool_names: + continue + invoke_body = invoke_match.group(2) + arguments = {} + for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): + param_name = param_match.group(1).strip() + param_value = param_match.group(2).strip() + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[param_name] = param_value + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches + + +def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]): + """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. + + Format: + <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|> + """ + matches = [] + for m in re.finditer( + r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*', + answer + ): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extractBalancedJson(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + return matches + + +def _parseGlmToolCalls(answer: str, tool_names: list[str]): + """Parse GLM-style tool calls using arg_key/arg_value XML pairs. + + Format: + function_name + key1 + value1 + + """ + matches = [] + for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): + tc_content = tc_match.group(1) + # First non-tag text is the function name + name_match = re.match(r'([^<\s]+)', tc_content.strip()) + if not name_match: + continue + func_name = name_match.group(1).strip() + if func_name not in tool_names: + continue + # Extract arg_key/arg_value pairs + keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] + vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] + if len(keys) != len(vals): + continue + arguments = {} + for k, v in zip(keys, vals): + try: + v = json.loads(v) + except (json.JSONDecodeError, ValueError): + pass # keep as string + arguments[k] = v + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + return matches + + def _parsePythonicToolCalls(answer: str, tool_names: list[str]): """Parse pythonic-style tool calls used by Llama 4 and similar models. @@ -244,11 +429,31 @@ def parseToolCall(answer: str, tool_names: list[str]): if len(answer) < 10: return matches + # Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters) + matches = _parseDeepSeekToolCalls(answer, tool_names) + if matches: + return matches + + # Check for Kimi-K2-style tool calls (pipe-delimited tokens) + matches = _parseKimiToolCalls(answer, tool_names) + if matches: + return matches + # Check for channel-based tool calls (e.g. GPT-OSS format) matches = _parseChannelToolCalls(answer, tool_names) if matches: return matches + # Check for MiniMax-style tool calls (invoke/parameter XML tags) + matches = _parseMiniMaxToolCalls(answer, tool_names) + if matches: + return matches + + # Check for GLM-style tool calls (arg_key/arg_value XML pairs) + matches = _parseGlmToolCalls(answer, tool_names) + if matches: + return matches + # Check for XML-parameter style tool calls (e.g. Qwen3.5 format) matches = _parseXmlParamToolCalls(answer, tool_names) if matches: diff --git a/modules/chat.py b/modules/chat.py index bc4fc1d8..36d373d6 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -106,6 +106,50 @@ yaml.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter) +class _JsonDict(dict): + """A dict that serializes as JSON when used in string concatenation. + + Some Jinja2 templates (Qwen, GLM) iterate arguments with .items(), + requiring a dict. Others (DeepSeek) concatenate arguments as a + string, requiring JSON. This class satisfies both. + """ + + def __str__(self): + return json.dumps(self, ensure_ascii=False) + + def __add__(self, other): + return str(self) + other + + def __radd__(self, other): + return other + str(self) + + +def _deserialize_tool_call_arguments(tool_calls): + """Convert tool_call arguments from JSON strings to _JsonDict. + + The OpenAI API spec sends arguments as a JSON string, but Jinja2 + templates may need a dict (.items()) or a string (concatenation). + _JsonDict handles both transparently. + """ + result = [] + for tc in tool_calls: + tc = copy.copy(tc) + func = tc.get('function', {}) + if isinstance(func, dict): + func = dict(func) + args = func.get('arguments') + if isinstance(args, str): + try: + func['arguments'] = _JsonDict(json.loads(args)) + except (json.JSONDecodeError, ValueError): + pass + elif isinstance(args, dict) and not isinstance(args, _JsonDict): + func['arguments'] = _JsonDict(args) + tc['function'] = func + result.append(tc) + return result + + def generate_chat_prompt(user_input, state, **kwargs): impersonate = kwargs.get('impersonate', False) _continue = kwargs.get('_continue', False) @@ -172,7 +216,7 @@ def generate_chat_prompt(user_input, state, **kwargs): if not assistant_msg and entry_meta.get('tool_calls'): # Assistant message with only tool_calls and no text content - messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": entry_meta['tool_calls']}) + messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": _deserialize_tool_call_arguments(entry_meta['tool_calls'])}) elif assistant_msg: # Handle GPT-OSS as a special case if '<|channel|>analysis<|message|>' in assistant_msg or '<|channel|>final<|message|>' in assistant_msg: @@ -250,7 +294,7 @@ def generate_chat_prompt(user_input, state, **kwargs): # Attach tool_calls metadata to the assistant message if present if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant': - messages[insert_pos]['tool_calls'] = entry_meta['tool_calls'] + messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls']) if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']: # Check for user message attachments in metadata