mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-08 06:33:51 +01:00
API: Add tool call parsing for DeepSeek, GLM, MiniMax, and Kimi models
This commit is contained in:
parent
f5acf55207
commit
044566d42d
|
|
@ -83,6 +83,39 @@ def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str
|
|||
return None
|
||||
|
||||
|
||||
def _extractBalancedJson(text: str, start: int) -> str | None:
|
||||
"""Extract a balanced JSON object from text starting at the given position.
|
||||
|
||||
Walks through the string tracking brace depth and string boundaries
|
||||
to correctly handle arbitrary nesting levels.
|
||||
"""
|
||||
if start >= len(text) or text[start] != '{':
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if c == '\\' and in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if c == '{':
|
||||
depth += 1
|
||||
elif c == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start:i + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse channel-based tool calls used by GPT-OSS and similar models.
|
||||
|
||||
|
|
@ -91,14 +124,17 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
|||
"""
|
||||
matches = []
|
||||
for m in re.finditer(
|
||||
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\})',
|
||||
r'<\|channel\|>commentary to=functions\.([^<\s]+)\s*(?:<\|constrain\|>json)?<\|message\|>',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
json_str = _extractBalancedJson(answer, m.end())
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(m.group(2))
|
||||
arguments = json.loads(json_str)
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -119,27 +155,33 @@ def _parseBareNameToolCalls(answer: str, tool_names: list[str]):
|
|||
Multiple calls are concatenated directly or separated by whitespace.
|
||||
"""
|
||||
matches = []
|
||||
# Build pattern that matches any known tool name followed by a JSON object
|
||||
# Match tool name followed by opening brace, then extract balanced JSON
|
||||
escaped_names = [re.escape(name) for name in tool_names]
|
||||
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
|
||||
pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{'
|
||||
for match in re.finditer(pattern, answer):
|
||||
text = match.group(0)
|
||||
# Split into function name and JSON arguments
|
||||
for name in tool_names:
|
||||
if text.startswith(name):
|
||||
json_str = text[len(name):].strip()
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
name = None
|
||||
for n in tool_names:
|
||||
if text.startswith(n):
|
||||
name = n
|
||||
break
|
||||
if not name:
|
||||
continue
|
||||
brace_start = match.end() - 1
|
||||
json_str = _extractBalancedJson(answer, brace_start)
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return matches
|
||||
|
||||
|
||||
|
|
@ -181,6 +223,149 @@ def _parseXmlParamToolCalls(answer: str, tool_names: list[str]):
|
|||
return matches
|
||||
|
||||
|
||||
def _parseKimiToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse Kimi-K2-style tool calls using pipe-delimited tokens.
|
||||
|
||||
Format:
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
"""
|
||||
matches = []
|
||||
for m in re.finditer(
|
||||
r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
json_str = _extractBalancedJson(answer, m.end())
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return matches
|
||||
|
||||
|
||||
def _parseMiniMaxToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse MiniMax-style tool calls using invoke/parameter XML tags.
|
||||
|
||||
Format:
|
||||
<minimax:tool_call>
|
||||
<invoke name="function_name">
|
||||
<parameter name="param_name">value</parameter>
|
||||
</invoke>
|
||||
</minimax:tool_call>
|
||||
"""
|
||||
matches = []
|
||||
for tc_match in re.finditer(r'<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>', answer, re.DOTALL):
|
||||
tc_content = tc_match.group(1)
|
||||
# Split on <invoke> to handle multiple parallel calls in one block
|
||||
for invoke_match in re.finditer(r'<invoke\s+name="([^"]+)">(.*?)</invoke>', tc_content, re.DOTALL):
|
||||
func_name = invoke_match.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
invoke_body = invoke_match.group(2)
|
||||
arguments = {}
|
||||
for param_match in re.finditer(r'<parameter\s+name="([^"]+)">\s*(.*?)\s*</parameter>', invoke_body, re.DOTALL):
|
||||
param_name = param_match.group(1).strip()
|
||||
param_value = param_match.group(2).strip()
|
||||
try:
|
||||
param_value = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass # keep as string
|
||||
arguments[param_name] = param_value
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
return matches
|
||||
|
||||
|
||||
def _parseDeepSeekToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters.
|
||||
|
||||
Format:
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|>
|
||||
"""
|
||||
matches = []
|
||||
for m in re.finditer(
|
||||
r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
json_str = _extractBalancedJson(answer, m.end())
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return matches
|
||||
|
||||
|
||||
def _parseGlmToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse GLM-style tool calls using arg_key/arg_value XML pairs.
|
||||
|
||||
Format:
|
||||
<tool_call>function_name
|
||||
<arg_key>key1</arg_key>
|
||||
<arg_value>value1</arg_value>
|
||||
</tool_call>
|
||||
"""
|
||||
matches = []
|
||||
for tc_match in re.finditer(r'<tool_call>\s*(.*?)\s*</tool_call>', answer, re.DOTALL):
|
||||
tc_content = tc_match.group(1)
|
||||
# First non-tag text is the function name
|
||||
name_match = re.match(r'([^<\s]+)', tc_content.strip())
|
||||
if not name_match:
|
||||
continue
|
||||
func_name = name_match.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
# Extract arg_key/arg_value pairs
|
||||
keys = [k.group(1).strip() for k in re.finditer(r'<arg_key>\s*(.*?)\s*</arg_key>', tc_content, re.DOTALL)]
|
||||
vals = [v.group(1).strip() for v in re.finditer(r'<arg_value>\s*(.*?)\s*</arg_value>', tc_content, re.DOTALL)]
|
||||
if len(keys) != len(vals):
|
||||
continue
|
||||
arguments = {}
|
||||
for k, v in zip(keys, vals):
|
||||
try:
|
||||
v = json.loads(v)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass # keep as string
|
||||
arguments[k] = v
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
return matches
|
||||
|
||||
|
||||
def _parsePythonicToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
|
||||
|
||||
|
|
@ -244,11 +429,31 @@ def parseToolCall(answer: str, tool_names: list[str]):
|
|||
if len(answer) < 10:
|
||||
return matches
|
||||
|
||||
# Check for DeepSeek-style tool calls (fullwidth Unicode token delimiters)
|
||||
matches = _parseDeepSeekToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for Kimi-K2-style tool calls (pipe-delimited tokens)
|
||||
matches = _parseKimiToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for channel-based tool calls (e.g. GPT-OSS format)
|
||||
matches = _parseChannelToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for MiniMax-style tool calls (invoke/parameter XML tags)
|
||||
matches = _parseMiniMaxToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for GLM-style tool calls (arg_key/arg_value XML pairs)
|
||||
matches = _parseGlmToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
# Check for XML-parameter style tool calls (e.g. Qwen3.5 format)
|
||||
matches = _parseXmlParamToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
|
|
|
|||
|
|
@ -106,6 +106,50 @@ yaml.add_representer(str, str_presenter)
|
|||
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
class _JsonDict(dict):
|
||||
"""A dict that serializes as JSON when used in string concatenation.
|
||||
|
||||
Some Jinja2 templates (Qwen, GLM) iterate arguments with .items(),
|
||||
requiring a dict. Others (DeepSeek) concatenate arguments as a
|
||||
string, requiring JSON. This class satisfies both.
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(self, ensure_ascii=False)
|
||||
|
||||
def __add__(self, other):
|
||||
return str(self) + other
|
||||
|
||||
def __radd__(self, other):
|
||||
return other + str(self)
|
||||
|
||||
|
||||
def _deserialize_tool_call_arguments(tool_calls):
|
||||
"""Convert tool_call arguments from JSON strings to _JsonDict.
|
||||
|
||||
The OpenAI API spec sends arguments as a JSON string, but Jinja2
|
||||
templates may need a dict (.items()) or a string (concatenation).
|
||||
_JsonDict handles both transparently.
|
||||
"""
|
||||
result = []
|
||||
for tc in tool_calls:
|
||||
tc = copy.copy(tc)
|
||||
func = tc.get('function', {})
|
||||
if isinstance(func, dict):
|
||||
func = dict(func)
|
||||
args = func.get('arguments')
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
func['arguments'] = _JsonDict(json.loads(args))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
elif isinstance(args, dict) and not isinstance(args, _JsonDict):
|
||||
func['arguments'] = _JsonDict(args)
|
||||
tc['function'] = func
|
||||
result.append(tc)
|
||||
return result
|
||||
|
||||
|
||||
def generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs.get('impersonate', False)
|
||||
_continue = kwargs.get('_continue', False)
|
||||
|
|
@ -172,7 +216,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
if not assistant_msg and entry_meta.get('tool_calls'):
|
||||
# Assistant message with only tool_calls and no text content
|
||||
messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": entry_meta['tool_calls']})
|
||||
messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": _deserialize_tool_call_arguments(entry_meta['tool_calls'])})
|
||||
elif assistant_msg:
|
||||
# Handle GPT-OSS as a special case
|
||||
if '<|channel|>analysis<|message|>' in assistant_msg or '<|channel|>final<|message|>' in assistant_msg:
|
||||
|
|
@ -250,7 +294,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
# Attach tool_calls metadata to the assistant message if present
|
||||
if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant':
|
||||
messages[insert_pos]['tool_calls'] = entry_meta['tool_calls']
|
||||
messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls'])
|
||||
|
||||
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
# Check for user message attachments in metadata
|
||||
|
|
|
|||
Loading…
Reference in a new issue