From 6e2b70bde60c089f97b0abe97bb1b594cce75357 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 2 Apr 2026 20:26:09 -0700 Subject: [PATCH] Add Gemma 4 tool-calling support --- modules/chat.py | 57 +++++++++++++++++++++++++++++ modules/reasoning.py | 1 + modules/tool_parsing.py | 79 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/modules/chat.py b/modules/chat.py index edda11b0..818309e6 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -210,6 +210,57 @@ def _expand_tool_sequence(tool_seq): return messages +def _convert_to_tool_responses(messages): + """Convert role:'tool' messages to tool_responses format. + + Templates like Gemma 4 expect tool results as a ``tool_responses`` + attribute on a message rather than separate ``role: 'tool'`` messages. + This function groups consecutive tool messages and rewrites them. + """ + result = [] + tc_id_to_name = {} + + i = 0 + while i < len(messages): + msg = messages[i] + + if msg.get('tool_calls'): + for tc in msg['tool_calls']: + tc_id = tc.get('id', '') + func_name = tc.get('function', {}).get('name', 'unknown') + if tc_id: + tc_id_to_name[tc_id] = func_name + + if msg.get('role') == 'tool': + tool_responses = [] + while i < len(messages) and messages[i].get('role') == 'tool': + tool_msg = messages[i] + tc_id = tool_msg.get('tool_call_id', '') + func_name = tc_id_to_name.get(tc_id, 'unknown') + + content = tool_msg.get('content', '') + try: + response = json.loads(content) + except (json.JSONDecodeError, ValueError, TypeError): + response = content + + tool_responses.append({ + 'name': func_name, + 'response': response, + }) + i += 1 + + result.append({ + 'role': 'tool', + 'tool_responses': tool_responses, + }) + else: + result.append(msg) + i += 1 + + return result + + def _format_attachments(attachments, include_text=True): """Build image ref and text attachment strings from a list of attachments.""" attachments_text = "" @@ -267,6 +318,9 @@ def generate_chat_prompt(user_input, state, **kwargs): tools=state['tools'] if 'tools' in state else None, ) + active_template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str + uses_tool_responses = 'tool_responses' in active_template_str + messages = [] if state['mode'] == 'instruct': @@ -503,6 +557,9 @@ def generate_chat_prompt(user_input, state, **kwargs): return prompt + if uses_tool_responses: + messages = _convert_to_tool_responses(messages) + prompt = make_prompt(messages) # Handle truncation diff --git a/modules/reasoning.py b/modules/reasoning.py index aa1939b8..4a7cfa79 100644 --- a/modules/reasoning.py +++ b/modules/reasoning.py @@ -7,6 +7,7 @@ THINKING_FORMATS = [ ('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'), ('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'), ('', '', None), + ('<|channel>thought', '', None), # Gemma 4 ('<|think|>', '<|end|>', '<|content|>'), # Solar Open # ('Thinking Process:', '', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming (None, '', None), # End-only variant (e.g., Qwen3-next) diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py index ec49f77f..45da25c9 100644 --- a/modules/tool_parsing.py +++ b/modules/tool_parsing.py @@ -27,6 +27,7 @@ TOOL_CALL_OPENING_MARKERS = [ '[TOOL_CALLS]', 'to=functions.', '<|channel|>commentary', + '<|tool_call>call:', ] @@ -400,6 +401,78 @@ def _parse_glm_tool_calls(answer: str, tool_names: list[str]): return matches, start_pos +def _extract_gemma4_balanced(text, start): + """Extract balanced braces from Gemma 4 format, using <|"|> as string delimiters.""" + if start >= len(text) or text[start] != '{': + return None + depth = 0 + in_string = False + quote_token = '<|"|>' + quote_len = len(quote_token) + i = start + while i < len(text): + if text[i:i + quote_len] == quote_token: + in_string = not in_string + i += quote_len + continue + if in_string: + i += 1 + continue + c = text[i] + if c == '{': + depth += 1 + elif c == '}': + depth -= 1 + if depth == 0: + return text[start:i + 1] + i += 1 + return None + + +def _parse_gemma4_tool_calls(answer: str, tool_names: list[str]): + """Parse Gemma 4-style tool calls. + + Format: + <|tool_call>call:func_name{key:<|"|>value<|"|>,...} + + Values use <|"|> tokens instead of standard JSON quotes, and keys are + bare identifiers. + """ + matches = [] + start_pos = None + + for m in re.finditer(r'<\|tool_call>call:([^\s{]+)\s*', answer): + func_name = m.group(1).strip() + if func_name not in tool_names: + continue + + brace_start = m.end() + if brace_start >= len(answer) or answer[brace_start] != '{': + continue + + content = _extract_gemma4_balanced(answer, brace_start) + if content is None: + continue + + # Convert to JSON: split on <|"|> tokens so that key quoting + # only applies outside string values (even-indexed parts), + # then rejoin with real quotes. + parts = content.split('<|"|>') + for idx in range(0, len(parts), 2): + parts[idx] = re.sub(r'(^|[{,\[])\s*(\w+)\s*:', r'\1"\2":', parts[idx]) + json_str = '"'.join(parts) + + try: + arguments = json.loads(json_str) + if start_pos is None: + start_pos = m.start() + matches.append(_make_tool_call(func_name, arguments)) + except (json.JSONDecodeError, ValueError): + pass + + return matches, start_pos + + def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): """Parse pythonic-style tool calls used by Llama 4 and similar models. @@ -472,6 +545,11 @@ TOOL_CALL_FORMATS = [ 'parser': _parse_channel_tool_calls, 'markers': ['to=functions.', '<|channel|>commentary'], }, + { + 'template_hints': ['<|tool_call>call:'], + 'parser': _parse_gemma4_tool_calls, + 'markers': ['<|tool_call>call:'], + }, { 'template_hints': ['minimax:tool_call'], 'parser': _parse_minimax_tool_calls, @@ -504,6 +582,7 @@ ALL_PARSERS = [ _parse_deep_seek_tool_calls, _parse_kimi_tool_calls, _parse_channel_tool_calls, + _parse_gemma4_tool_calls, _parse_minimax_tool_calls, _parse_glm_tool_calls, _parse_xml_param_tool_calls,