From e0a38da9f31c95332a5ca863217b1b2e485aecdc Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 13 Mar 2026 11:00:12 -0300 Subject: [PATCH] Improve tool call parsing for Devstral/GPT-OSS and preserve thinking across tool turns --- extensions/openai/utils.py | 51 +++++++++++++++++++++++++++++++++++--- modules/chat.py | 16 ++++++++++-- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index eb34ce88..b179c267 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -126,8 +126,49 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]): """ matches = [] start_pos = None - for m in re.finditer( + # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format) + # Pattern 2: to=functions.NAME after <|channel|> (alternative format) + patterns = [ + r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>', r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>', + ] + for pattern in patterns: + for m in re.finditer(pattern, answer): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + json_str = _extractBalancedJson(answer, m.end()) + if json_str is None: + continue + try: + arguments = json.loads(json_str) + if start_pos is None: + prefix = answer.rfind('<|start|>assistant', 0, m.start()) + start_pos = prefix if prefix != -1 else m.start() + matches.append({ + "type": "function", + "function": { + "name": func_name, + "arguments": arguments + } + }) + except json.JSONDecodeError: + pass + if matches: + break + return matches, start_pos + + +def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]): + """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens. + + Format: + [TOOL_CALLS]func_name[ARGS]{"arg": "value"} + """ + matches = [] + start_pos = None + for m in re.finditer( + r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*', answer ): func_name = m.group(1).strip() @@ -139,8 +180,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]): try: arguments = json.loads(json_str) if start_pos is None: - prefix = answer.rfind('<|start|>assistant', 0, m.start()) - start_pos = prefix if prefix != -1 else m.start() + start_pos = m.start() matches.append({ "type": "function", "function": { @@ -497,6 +537,11 @@ def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = Fals if matches: return _return(matches, start_pos) + # Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json) + matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names) + if matches: + return _return(matches, start_pos) + # Check for bare function-name style tool calls (e.g. Mistral format) matches, start_pos = _parseBareNameToolCalls(answer, tool_names) if matches: diff --git a/modules/chat.py b/modules/chat.py index 57fd50e0..2c6f0ab2 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -177,7 +177,7 @@ def _expand_tool_sequence(tool_seq): deserialized = _deserialize_tool_call_arguments(item['tool_calls']) messages.append({ "role": "assistant", - "content": "", + "content": item.get('content', ''), "tool_calls": deserialized }) for tc in item['tool_calls']: @@ -1324,7 +1324,19 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): tc_headers.append(f'{fn_name}({args_summary})') - seq.append({'tool_calls': serialized}) + seq_entry = {'tool_calls': serialized} + if content_prefix.strip(): + # Strip GPT-OSS channel tokens so they don't get double-wrapped + # by the template (which adds its own channel markup). + clean = content_prefix.strip() + if '<|channel|>' in clean and '<|message|>' in clean: + inner = clean.split('<|message|>', 1)[1] if '<|message|>' in clean else clean + if '<|end|>' in inner: + inner = inner.split('<|end|>', 1)[0] + clean = inner.strip() + if clean: + seq_entry['content'] = clean + seq.append(seq_entry) # Clear internal (raw tool markup) history['internal'][-1][1] = ''