mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-08 06:33:51 +01:00
API: Support Llama 4 tool calling and fix tool calling edge cases
This commit is contained in:
parent
160f7ad6b4
commit
3531069824
|
|
@ -603,6 +603,12 @@ def validateTools(tools: list[dict]):
|
|||
tool = tools[idx]
|
||||
try:
|
||||
tool_definition = ToolDefinition(**tool)
|
||||
# Backfill defaults so Jinja2 templates don't crash on missing fields
|
||||
func = tool.get("function", {})
|
||||
if "description" not in func:
|
||||
func["description"] = ""
|
||||
if "parameters" not in func:
|
||||
func["parameters"] = {"type": "object", "properties": {}}
|
||||
if valid_tools is None:
|
||||
valid_tools = []
|
||||
valid_tools.append(tool)
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
|||
"""
|
||||
matches = []
|
||||
for m in re.finditer(
|
||||
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^}]*(?:\{[^}]*\}[^}]*)*\})',
|
||||
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
|
|
@ -181,6 +181,62 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
|
|||
return matches
|
||||
|
||||
|
||||
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
||||
|
||||
Format:
|
||||
[func_name(param1="value1", param2="value2"), func_name2(...)]
|
||||
"""
|
||||
matches = []
|
||||
# Match a bracketed list of function calls
|
||||
bracket_match = re.search(r'\[([^\[\]]+)\]', answer)
|
||||
if not bracket_match:
|
||||
return matches
|
||||
|
||||
inner = bracket_match.group(1)
|
||||
|
||||
# Build pattern for known tool names
|
||||
escaped_names = [re.escape(name) for name in tool_names]
|
||||
name_pattern = '|'.join(escaped_names)
|
||||
|
||||
for call_match in re.finditer(
|
||||
r'(' + name_pattern + r')\(([^)]*)\)',
|
||||
inner
|
||||
):
|
||||
func_name = call_match.group(1)
|
||||
params_str = call_match.group(2).strip()
|
||||
arguments = {}
|
||||
|
||||
if params_str:
|
||||
# Parse key="value" pairs, handling commas inside quoted values
|
||||
for param_match in re.finditer(
|
||||
r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)',
|
||||
params_str
|
||||
):
|
||||
param_name = param_match.group(1)
|
||||
param_value = param_match.group(2).strip()
|
||||
# Strip surrounding quotes
|
||||
if (param_value.startswith('"') and param_value.endswith('"')) or \
|
||||
(param_value.startswith("'") and param_value.endswith("'")):
|
||||
param_value = param_value[1:-1]
|
||||
# Try to parse as JSON for numeric/bool/null values
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
arguments[param_name] = param_value
|
||||
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def parseToolCall(answer: str, tool_names: list[str]):
|
||||
matches = []
|
||||
|
||||
|
|
@ -203,6 +259,11 @@ def parseToolCall(answer: str, tool_names: list[str]):
|
|||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for pythonic-style tool calls (e.g. Llama 4 format)
|
||||
matches = _parsePythonicToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
|
||||
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue