API: Add tool call parsing for DeepSeek, GLM, MiniMax, and Kimi models

This commit is contained in:
oobabooga 2026-03-06 15:06:16 -03:00
parent f5acf55207
commit 044566d42d
2 changed files with 270 additions and 21 deletions

View file

@ -83,6 +83,39 @@ def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str
return None
def _extractBalancedJson(text: str, start: int) -> str | None:
"""Extract a balanced JSON object from text starting at the given position.
Walks through the string tracking brace depth and string boundaries
to correctly handle arbitrary nesting levels.
"""
if start >= len(text) or text[start] != '{':
return None
depth = 0
in_string = False
escape_next = False
for i in range(start, len(text)):
c = text[i]
if escape_next:
escape_next = False
continue
if c == '\\' and in_string:
escape_next = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == '{':
depth += 1
elif c == '}':
depth -= 1
if depth == 0:
return text[start:i + 1]
return None
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""Parse channel-based tool calls used by GPT-OSS and similar models.
@ -91,14 +124,17 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""
matches = []
for m in re.finditer(
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(m.group(2))
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
@ -119,27 +155,33 @@ def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
Multiple calls are concatenated directly or separated by whitespace.
"""
matches = []
# Build pattern that matches any known tool name followed by a JSON object
# Match tool name followed by opening brace, then extract balanced JSON
escaped_names = [re.escape(name) for name in tool_names]
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
for match in re.finditer(pattern, answer):
text = match.group(0)
# Split into function name and JSON arguments
for name in tool_names:
if text.startswith(name):
json_str = text[len(name):].strip()
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
name = None
for n in tool_names:
if text.startswith(n):
name = n
break
if not name:
continue
brace_start = match.end() - 1
json_str = _extractBalancedJson(answer, brace_start)
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
@ -181,6 +223,149 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
return matches
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
Format:
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
<|tool_calls_section_end|>
"""
matches = []
for m in re.finditer(
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
Format:
<minimax:tool_call>
<invoke name="function_name">
<parameter name="param_name">value</parameter>
</invoke>
</minimax:tool_call>
"""
matches = []
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# Split on <invoke> to handle multiple parallel calls in one block
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
func_name = invoke_match.group(1).strip()
if func_name not in tool_names:
continue
invoke_body = invoke_match.group(2)
arguments = {}
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
param_value = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[param_name] = param_value
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
Format:
<toolcallsbegin><toolcallbegin>func_name<toolsep>{"arg": "value"}<toolcallend><toolcallsend>
"""
matches = []
for m in re.finditer(
r'<tool▁call▁begin>\s*(\S+?)\s*<tool▁sep>\s*',
answer
):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
return matches
def _parseGlmToolCalls(answer: str, tool_names: list[str]):
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
Format:
<tool_call>function_name
<arg_key>key1</arg_key>
<arg_value>value1</arg_value>
</tool_call>
"""
matches = []
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
tc_content = tc_match.group(1)
# First non-tag text is the function name
name_match = re.match(r'([^<\s]+)', tc_content.strip())
if not name_match:
continue
func_name = name_match.group(1).strip()
if func_name not in tool_names:
continue
# Extract arg_key/arg_value pairs
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
if len(keys) != len(vals):
continue
arguments = {}
for k, v in zip(keys, vals):
try:
v = json.loads(v)
except (json.JSONDecodeError, ValueError):
pass # keep as string
arguments[k] = v
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
return matches
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
@ -244,11 +429,31 @@ def parseToolCall(answer: str, tool_names: list[str]):
if len(answer) < 10:
return matches
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
matches = _parseDeepSeekToolCalls(answer, tool_names)
if matches:
return matches
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
matches = _parseKimiToolCalls(answer, tool_names)
if matches:
return matches
# Check for channel-based tool calls (e.g. GPT-OSS format)
matches = _parseChannelToolCalls(answer, tool_names)
if matches:
return matches
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
matches = _parseMiniMaxToolCalls(answer, tool_names)
if matches:
return matches
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
matches = _parseGlmToolCalls(answer, tool_names)
if matches:
return matches
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
matches = _parseXmlParamToolCalls(answer, tool_names)
if matches:

View file

@ -106,6 +106,50 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
class _JsonDict(dict):
"""A dict that serializes as JSON when used in string concatenation.
Some Jinja2 templates (Qwen, GLM) iterate arguments with .items(),
requiring a dict. Others (DeepSeek) concatenate arguments as a
string, requiring JSON. This class satisfies both.
"""
def __str__(self):
return json.dumps(self, ensure_ascii=False)
def __add__(self, other):
return str(self) + other
def __radd__(self, other):
return other + str(self)
def _deserialize_tool_call_arguments(tool_calls):
"""Convert tool_call arguments from JSON strings to _JsonDict.
The OpenAI API spec sends arguments as a JSON string, but Jinja2
templates may need a dict (.items()) or a string (concatenation).
_JsonDict handles both transparently.
"""
result = []
for tc in tool_calls:
tc = copy.copy(tc)
func = tc.get('function', {})
if isinstance(func, dict):
func = dict(func)
args = func.get('arguments')
if isinstance(args, str):
try:
func['arguments'] = _JsonDict(json.loads(args))
except (json.JSONDecodeError, ValueError):
pass
elif isinstance(args, dict) and not isinstance(args, _JsonDict):
func['arguments'] = _JsonDict(args)
tc['function'] = func
result.append(tc)
return result
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
@ -172,7 +216,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
if not assistant_msg and entry_meta.get('tool_calls'):
# Assistant message with only tool_calls and no text content
messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": entry_meta['tool_calls']})
messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": _deserialize_tool_call_arguments(entry_meta['tool_calls'])})
elif assistant_msg:
# Handle GPT-OSS as a special case
if '<|channel|>analysis<|message|>' in assistant_msg or '<|channel|>final<|message|>' in assistant_msg:
@ -250,7 +294,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Attach tool_calls metadata to the assistant message if present
if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
messages[insert_pos]['tool_calls'] = entry_meta['tool_calls']
messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
# Check for user message attachments in metadata