API: Support Llama 4 tool calling and fix tool calling edge cases

This commit is contained in:
oobabooga 2026-03-06 13:12:14 -03:00
parent 160f7ad6b4
commit 3531069824
2 changed files with 68 additions and 1 deletions

View file

@ -603,6 +603,12 @@ def validateTools(tools: list[dict]):
tool = tools[idx]
try:
tool_definition = ToolDefinition(**tool)
# Backfill defaults so Jinja2 templates don't crash on missing fields
func = tool.get("function", {})
if "description" not in func:
func["description"] = ""
if "parameters" not in func:
func["parameters"] = {"type": "object", "properties": {}}
if valid_tools is None:
valid_tools = []
valid_tools.append(tool)

View file

@ -91,7 +91,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""
matches = []
for m in re.finditer(
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^}]*(?:\{[^}]*\}[^}]*)*\})',
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
answer
):
func_name = m.group(1).strip()
@ -181,6 +181,62 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
return matches
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
Format:
[func_name(param1="value1", param2="value2"), func_name2(...)]
"""
matches = []
# Match a bracketed list of function calls
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
if not bracket_match:
return matches
inner = bracket_match.group(1)
# Build pattern for known tool names
escaped_names = [re.escape(name) for name in tool_names]
name_pattern = '|'.join(escaped_names)
for call_match in re.finditer(
r'(' + name_pattern + r')\(([^)]*)\)',
inner
):
func_name = call_match.group(1)
params_str = call_match.group(2).strip()
arguments = {}
if params_str:
# Parse key="value" pairs, handling commas inside quoted values
for param_match in re.finditer(
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
params_str
):
param_name = param_match.group(1)
param_value = param_match.group(2).strip()
# Strip surrounding quotes
if (param_value.startswith('"') and param_value.endswith('"')) or \
(param_value.startswith("'") and param_value.endswith("'")):
param_value = param_value[1:-1]
# Try to parse as JSON for numeric/bool/null values
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def parseToolCall(answer: str, tool_names: list[str]):
matches = []
@ -203,6 +259,11 @@ def parseToolCall(answer: str, tool_names: list[str]):
if matches:
return matches
# Check for pythonic-style tool calls (e.g. Llama 4 format)
matches = _parsePythonicToolCalls(answer, tool_names)
if matches:
return matches
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]