mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Improve tool call parsing for Devstral/GPT-OSS and preserve thinking across tool turns
This commit is contained in:
parent
e50b823eee
commit
e0a38da9f3
|
|
@ -126,8 +126,49 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
|||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for m in re.finditer(
|
||||
# Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format)
|
||||
# Pattern 2: to=functions.NAME after <|channel|> (alternative format)
|
||||
patterns = [
|
||||
r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>',
|
||||
r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>',
|
||||
]
|
||||
for pattern in patterns:
|
||||
for m in re.finditer(pattern, answer):
|
||||
func_name = m.group(1).strip()
|
||||
if func_name not in tool_names:
|
||||
continue
|
||||
json_str = _extractBalancedJson(answer, m.end())
|
||||
if json_str is None:
|
||||
continue
|
||||
try:
|
||||
arguments = json.loads(json_str)
|
||||
if start_pos is None:
|
||||
prefix = answer.rfind('<|start|>assistant', 0, m.start())
|
||||
start_pos = prefix if prefix != -1 else m.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if matches:
|
||||
break
|
||||
return matches, start_pos
|
||||
|
||||
|
||||
def _parseMistralTokenToolCalls(answer: str, tool_names: list[str]):
|
||||
"""Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens.
|
||||
|
||||
Format:
|
||||
[TOOL_CALLS]func_name[ARGS]{"arg": "value"}
|
||||
"""
|
||||
matches = []
|
||||
start_pos = None
|
||||
for m in re.finditer(
|
||||
r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*',
|
||||
answer
|
||||
):
|
||||
func_name = m.group(1).strip()
|
||||
|
|
@ -139,8 +180,7 @@ def _parseChannelToolCalls(answer: str, tool_names: list[str]):
|
|||
try:
|
||||
arguments = json.loads(json_str)
|
||||
if start_pos is None:
|
||||
prefix = answer.rfind('<|start|>assistant', 0, m.start())
|
||||
start_pos = prefix if prefix != -1 else m.start()
|
||||
start_pos = m.start()
|
||||
matches.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -497,6 +537,11 @@ def parseToolCall(answer: str, tool_names: list[str], return_prefix: bool = Fals
|
|||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for Mistral/Devstral-style tool calls ([TOOL_CALLS]name[ARGS]json)
|
||||
matches, start_pos = _parseMistralTokenToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
return _return(matches, start_pos)
|
||||
|
||||
# Check for bare function-name style tool calls (e.g. Mistral format)
|
||||
matches, start_pos = _parseBareNameToolCalls(answer, tool_names)
|
||||
if matches:
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ def _expand_tool_sequence(tool_seq):
|
|||
deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"content": item.get('content', ''),
|
||||
"tool_calls": deserialized
|
||||
})
|
||||
for tc in item['tool_calls']:
|
||||
|
|
@ -1324,7 +1324,19 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
|
||||
tc_headers.append(f'{fn_name}({args_summary})')
|
||||
|
||||
seq.append({'tool_calls': serialized})
|
||||
seq_entry = {'tool_calls': serialized}
|
||||
if content_prefix.strip():
|
||||
# Strip GPT-OSS channel tokens so they don't get double-wrapped
|
||||
# by the template (which adds its own channel markup).
|
||||
clean = content_prefix.strip()
|
||||
if '<|channel|>' in clean and '<|message|>' in clean:
|
||||
inner = clean.split('<|message|>', 1)[1] if '<|message|>' in clean else clean
|
||||
if '<|end|>' in inner:
|
||||
inner = inner.split('<|end|>', 1)[0]
|
||||
clean = inner.strip()
|
||||
if clean:
|
||||
seq_entry['content'] = clean
|
||||
seq.append(seq_entry)
|
||||
|
||||
# Clear internal (raw tool markup)
|
||||
history['internal'][-1][1] = ''
|
||||
|
|
|
|||
Loading…
Reference in a new issue