From 3531069824a1ba1d06839b7d7b6b94ecac4ea1b1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 6 Mar 2026 13:12:14 -0300 Subject: [PATCH] API: Support Llama 4 tool calling and fix tool calling edge cases --- extensions/openai/completions.py | 6 +++ extensions/openai/utils.py | 63 +++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 297c9ba7..46502bdc 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -603,6 +603,12 @@ def validateTools(tools: list[dict]): tool = tools[idx] try: tool_definition = ToolDefinition(**tool) + # Backfill defaults so Jinja2 templates don't crash on missing fields + func = tool.get("function", {}) + if "description" not in func: + func["description"] = "" + if "parameters" not in func: + func["parameters"] = {"type": "object", "properties": {}} if valid_tools is None: valid_tools = [] valid_tools.append(tool) diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index 6937a108..64bd1631 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -91,7 +91,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]): """ matches = [] for m in re.finditer( - r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^}]*(?:\{[^}]*\}[^}]*)*\})', + r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})', answer ): func_name = m.group(1).strip() @@ -181,6 +181,62 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]): return matches +def _parsePythonicToolCalls(answer: str, tool_names: list[str]): + """Parse pythonic-style tool calls used by Llama 4 and similar models. + + Format: + [func_name(param1="value1", param2="value2"), func_name2(...)] + """ + matches = [] + # Match a bracketed list of function calls + bracket_match = re.search(r'\[([^\[\]]+)\]', answer) + if not bracket_match: + return matches + + inner = bracket_match.group(1) + + # Build pattern for known tool names + escaped_names = [re.escape(name) for name in tool_names] + name_pattern = '|'.join(escaped_names) + + for call_match in re.finditer( + r'(' + name_pattern + r')\(([^)]*)\)', + inner + ): + func_name = call_match.group(1) + params_str = call_match.group(2).strip() + arguments = {} + + if params_str: + # Parse key="value" pairs, handling commas inside quoted values + for param_match in re.finditer( + r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', + params_str + ): + param_name = param_match.group(1) + param_value = param_match.group(2).strip() + # Strip surrounding quotes + if (param_value.startswith('"') and param_value.endswith('"')) or \ + (param_value.startswith("'") and param_value.endswith("'")): + param_value = param_value[1:-1] + # Try to parse as JSON for numeric/bool/null values + try: + param_value = json.loads(param_value) + except (json.JSONDecodeError, ValueError): + pass + arguments[param_name] = param_value + + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + + return matches + + def parseToolCall(answer: str, tool_names: list[str]): matches = [] @@ -203,6 +259,11 @@ def parseToolCall(answer: str, tool_names: list[str]): if matches: return matches + # Check for pythonic-style tool calls (e.g. Llama 4 format) + matches = _parsePythonicToolCalls(answer, tool_names) + if matches: + return matches + # Define the regex pattern to find the JSON content wrapped in , , , and other tags observed from various models patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"]