Improve tool call parsing for Devstral/GPT-OSS and preserve thinking across tool turns

This commit is contained in:
oobabooga 2026-03-13 11:00:12 -03:00
parent e50b823eee
commit e0a38da9f3
2 changed files with 62 additions and 5 deletions

View file

@ -126,8 +126,49 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
"""
matches = []
start_pos = None
for m in re.finditer(
# Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
# Pattern 2: to=functions.NAME after <|channel|> (alternative format)
patterns = [
r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
]
for pattern in patterns:
for m in re.finditer(pattern, answer):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
json_str = _extractBalancedJson(answer, m.end())
if json_str is None:
continue
try:
arguments = json.loads(json_str)
if start_pos is None:
prefix = answer.rfind('<|start|>assistant', 0, m.start())
start_pos = prefix if prefix != -1 else m.start()
matches.append({
"type": "function",
"function": {
"name": func_name,
"arguments": arguments
}
})
except json.JSONDecodeError:
pass
if matches:
break
return matches, start_pos
def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]):
"""Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
Format:
[TOOL_CALLS]func_name[ARGS]{"arg": "value"}
"""
matches = []
start_pos = None
for m in re.finditer(
r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
answer
):
func_name = m.group(1).strip()
@ -139,8 +180,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
try:
arguments = json.loads(json_str)
if start_pos is None:
prefix = answer.rfind('<|start|>assistant', 0, m.start())
start_pos = prefix if prefix != -1 else m.start()
start_pos = m.start()
matches.append({
"type": "function",
"function": {
@ -497,6 +537,11 @@ def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = Fals
if matches:
return _return(matches, start_pos)
# Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names)
if matches:
return _return(matches, start_pos)
# Check for bare function-name style tool calls (e.g. Mistral format)
matches, start_pos = _parseBareNameToolCalls(answer, tool_names)
if matches:

View file

@ -177,7 +177,7 @@ def _expand_tool_sequence(tool_seq):
deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
messages.append({
"role": "assistant",
"content": "",
"content": item.get('content', ''),
"tool_calls": deserialized
})
for tc in item['tool_calls']:
@ -1324,7 +1324,19 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
tc_headers.append(f'{fn_name}({args_summary})')
seq.append({'tool_calls': serialized})
seq_entry = {'tool_calls': serialized}
if content_prefix.strip():
# Strip GPT-OSS channel tokens so they don't get double-wrapped
# by the template (which adds its own channel markup).
clean = content_prefix.strip()
if '<|channel|>' in clean and '<|message|>' in clean:
inner = clean.split('<|message|>', 1)[1] if '<|message|>' in clean else clean
if '<|end|>' in inner:
inner = inner.split('<|end|>', 1)[0]
clean = inner.strip()
if clean:
seq_entry['content'] = clean
seq.append(seq_entry)
# Clear the raw tool markup stored in the last internal history entry
history['internal'][-1][1] = ''