From 6e2b70bde60c089f97b0abe97bb1b594cce75357 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 2 Apr 2026 20:26:09 -0700
Subject: [PATCH] Add Gemma 4 tool-calling support
---
modules/chat.py | 57 +++++++++++++++++++++++++++++
modules/reasoning.py | 1 +
modules/tool_parsing.py | 79 +++++++++++++++++++++++++++++++++++++++++
3 files changed, 137 insertions(+)
diff --git a/modules/chat.py b/modules/chat.py
index edda11b0..818309e6 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -210,6 +210,57 @@ def _expand_tool_sequence(tool_seq):
return messages
+def _convert_to_tool_responses(messages):
+    """Convert ``role: 'tool'`` messages to the ``tool_responses`` format.
+
+    Templates like Gemma 4 expect tool results as a ``tool_responses``
+    attribute on a message rather than separate ``role: 'tool'`` messages.
+    Each run of consecutive tool messages is collapsed into a single
+    ``{'role': 'tool', 'tool_responses': [...]}`` message; every other
+    message is passed through untouched.
+
+    Args:
+        messages: List of chat message dicts. Messages carrying
+            ``tool_calls`` are used to map each call ``id`` to its
+            function name.
+
+    Returns:
+        A new list of messages; the input list is not modified. Each
+        ``tool_responses`` entry is ``{'name': ..., 'response': ...}``,
+        where ``response`` is the JSON-decoded content when it parses and
+        the raw content otherwise.
+    """
+    result = []
+    tc_id_to_name = {}
+
+    i = 0
+    while i < len(messages):
+        msg = messages[i]
+
+        # Remember which function each tool-call id refers to so that later
+        # tool results can be labelled with the function name.
+        if msg.get('tool_calls'):
+            for tc in msg['tool_calls']:
+                tc_id = tc.get('id', '')
+                func_name = tc.get('function', {}).get('name', 'unknown')
+                if tc_id:
+                    tc_id_to_name[tc_id] = func_name
+
+        if msg.get('role') != 'tool':
+            result.append(msg)
+            i += 1
+            continue
+
+        # Gather the whole run of consecutive tool messages into one entry.
+        tool_responses = []
+        while i < len(messages) and messages[i].get('role') == 'tool':
+            tool_msg = messages[i]
+            func_name = tc_id_to_name.get(tool_msg.get('tool_call_id', ''))
+            if func_name is None:
+                # No matching call id: fall back to the name carried by the
+                # tool message itself before giving up.
+                func_name = tool_msg.get('name') or 'unknown'
+
+            content = tool_msg.get('content', '')
+            try:
+                response = json.loads(content)
+            except (json.JSONDecodeError, ValueError, TypeError):
+                # Not JSON (or not a string at all): pass it through as-is.
+                response = content
+
+            tool_responses.append({
+                'name': func_name,
+                'response': response,
+            })
+            i += 1
+
+        result.append({
+            'role': 'tool',
+            'tool_responses': tool_responses,
+        })
+
+    return result
+
+
def _format_attachments(attachments, include_text=True):
"""Build image ref and text attachment strings from a list of attachments."""
attachments_text = ""
@@ -267,6 +318,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
tools=state['tools'] if 'tools' in state else None,
)
+ active_template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
+ uses_tool_responses = 'tool_responses' in active_template_str
+
messages = []
if state['mode'] == 'instruct':
@@ -503,6 +557,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
return prompt
+ if uses_tool_responses:
+ messages = _convert_to_tool_responses(messages)
+
prompt = make_prompt(messages)
# Handle truncation
diff --git a/modules/reasoning.py b/modules/reasoning.py
index aa1939b8..4a7cfa79 100644
--- a/modules/reasoning.py
+++ b/modules/reasoning.py
@@ -7,6 +7,7 @@ THINKING_FORMATS = [
('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('', '', None),
+ ('<|channel>thought', '', None), # Gemma 4
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
# ('Thinking Process:', '', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
(None, '', None), # End-only variant (e.g., Qwen3-next)
diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py
index ec49f77f..45da25c9 100644
--- a/modules/tool_parsing.py
+++ b/modules/tool_parsing.py
@@ -27,6 +27,7 @@ TOOL_CALL_OPENING_MARKERS = [
'[TOOL_CALLS]',
'to=functions.',
'<|channel|>commentary',
+ '<|tool_call>call:',
]
@@ -400,6 +401,78 @@ def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
return matches, start_pos
+def _extract_gemma4_balanced(text, start):
+    """Return the brace-balanced ``{...}`` span of *text* beginning at *start*.
+
+    Gemma 4 delimits string values with the ``<|"|>`` token rather than a
+    plain quote, so braces that appear between a pair of those tokens are
+    ignored when counting nesting depth.
+
+    Returns the substring from *start* through the matching closing brace,
+    or ``None`` when *start* does not point at ``{`` or the braces never
+    balance before the end of *text*.
+    """
+    if start >= len(text) or text[start] != '{':
+        return None
+
+    quote_token = '<|"|>'
+    step = len(quote_token)
+    end = len(text)
+    nesting = 0
+    inside_string = False
+    pos = start
+
+    while pos < end:
+        # A quote token flips string state and is skipped as a whole.
+        if text[pos:pos + step] == quote_token:
+            inside_string = not inside_string
+            pos += step
+            continue
+
+        if not inside_string:
+            ch = text[pos]
+            if ch == '{':
+                nesting += 1
+            elif ch == '}':
+                nesting -= 1
+                if nesting == 0:
+                    return text[start:pos + 1]
+
+        pos += 1
+
+    return None
+
+
+def _parse_gemma4_tool_calls(answer: str, tool_names: list[str]):
+    """Parse Gemma 4-style tool calls.
+
+    Format:
+        <|tool_call>call:func_name{key:<|"|>value<|"|>,...}
+
+    Values use <|"|> tokens instead of standard JSON quotes, and keys are
+    bare identifiers. The argument object is rewritten into JSON and
+    decoded; calls whose name is not in *tool_names*, whose braces do not
+    balance, or whose arguments still fail to decode are skipped.
+
+    Args:
+        answer: Raw model output to scan.
+        tool_names: Names of the tools that may be called.
+
+    Returns:
+        A ``(matches, start_pos)`` tuple: ``matches`` is the list of values
+        produced by ``_make_tool_call`` for each parsed call, and
+        ``start_pos`` is the index in *answer* of the first successfully
+        parsed call, or ``None`` when nothing was parsed.
+    """
+    matches = []
+    start_pos = None
+
+    for m in re.finditer(r'<\|tool_call>call:([^\s{]+)\s*', answer):
+        func_name = m.group(1).strip()
+        if func_name not in tool_names:
+            continue
+
+        brace_start = m.end()
+        if brace_start >= len(answer) or answer[brace_start] != '{':
+            continue
+
+        content = _extract_gemma4_balanced(answer, brace_start)
+        if content is None:
+            continue
+
+        # Convert to JSON. Splitting on <|"|> leaves structural text at the
+        # even indexes and raw string values at the odd ones.
+        segments = content.split('<|"|>')
+        for idx, segment in enumerate(segments):
+            if idx % 2:
+                # String value: encode it as a JSON string so embedded
+                # quotes, backslashes and newlines survive instead of
+                # producing invalid JSON and dropping the call.
+                segments[idx] = json.dumps(segment)
+            else:
+                # Structural text: quote the bare identifier keys.
+                segments[idx] = re.sub(r'(^|[{,\[])\s*(\w+)\s*:', r'\1"\2":', segment)
+        json_str = ''.join(segments)
+
+        try:
+            arguments = json.loads(json_str)
+        except (json.JSONDecodeError, ValueError):
+            # Best-effort parser: ignore calls that are still malformed.
+            continue
+
+        if start_pos is None:
+            start_pos = m.start()
+        matches.append(_make_tool_call(func_name, arguments))
+
+    return matches, start_pos
+
+
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
@@ -472,6 +545,11 @@ TOOL_CALL_FORMATS = [
'parser': _parse_channel_tool_calls,
'markers': ['to=functions.', '<|channel|>commentary'],
},
+ {
+ 'template_hints': ['<|tool_call>call:'],
+ 'parser': _parse_gemma4_tool_calls,
+ 'markers': ['<|tool_call>call:'],
+ },
{
'template_hints': ['minimax:tool_call'],
'parser': _parse_minimax_tool_calls,
@@ -504,6 +582,7 @@ ALL_PARSERS = [
_parse_deep_seek_tool_calls,
_parse_kimi_tool_calls,
_parse_channel_tool_calls,
+ _parse_gemma4_tool_calls,
_parse_minimax_tool_calls,
_parse_glm_tool_calls,
_parse_xml_param_tool_calls,