Add Gemma 4 tool-calling support

This commit is contained in:
oobabooga 2026-04-02 20:26:09 -07:00
parent b108c55353
commit 6e2b70bde6
3 changed files with 137 additions and 0 deletions

View file

@ -210,6 +210,57 @@ def _expand_tool_sequence(tool_seq):
return messages
def _convert_to_tool_responses(messages):
"""Convert role:'tool' messages to tool_responses format.
Templates like Gemma 4 expect tool results as a ``tool_responses``
attribute on a message rather than separate ``role: 'tool'`` messages.
This function groups consecutive tool messages and rewrites them.
"""
result = []
tc_id_to_name = {}
i = 0
while i < len(messages):
msg = messages[i]
if msg.get('tool_calls'):
for tc in msg['tool_calls']:
tc_id = tc.get('id', '')
func_name = tc.get('function', {}).get('name', 'unknown')
if tc_id:
tc_id_to_name[tc_id] = func_name
if msg.get('role') == 'tool':
tool_responses = []
while i < len(messages) and messages[i].get('role') == 'tool':
tool_msg = messages[i]
tc_id = tool_msg.get('tool_call_id', '')
func_name = tc_id_to_name.get(tc_id, 'unknown')
content = tool_msg.get('content', '')
try:
response = json.loads(content)
except (json.JSONDecodeError, ValueError, TypeError):
response = content
tool_responses.append({
'name': func_name,
'response': response,
})
i += 1
result.append({
'role': 'tool',
'tool_responses': tool_responses,
})
else:
result.append(msg)
i += 1
return result
def _format_attachments(attachments, include_text=True):
"""Build image ref and text attachment strings from a list of attachments."""
attachments_text = ""
@ -267,6 +318,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
tools=state['tools'] if 'tools' in state else None,
)
active_template_str = state['instruction_template_str'] if state['mode'] == 'instruct' else chat_template_str
uses_tool_responses = 'tool_responses' in active_template_str
messages = []
if state['mode'] == 'instruct':
@ -503,6 +557,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
return prompt
if uses_tool_responses:
messages = _convert_to_tool_responses(messages)
prompt = make_prompt(messages)
# Handle truncation

View file

@ -7,6 +7,7 @@ THINKING_FORMATS = [
('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'),
('<seed:think>', '</seed:think>', None),
('<|channel>thought', '<channel|>', None), # Gemma 4
('<|think|>', '<|end|>', '<|content|>'), # Solar Open
# ('Thinking Process:', '</think>', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming
(None, '</think>', None), # End-only variant (e.g., Qwen3-next)

View file

@ -27,6 +27,7 @@ TOOL_CALL_OPENING_MARKERS = [
'[TOOL_CALLS]',
'to=functions.',
'<|channel|>commentary',
'<|tool_call>call:',
]
@ -400,6 +401,78 @@ def _parse_glm_tool_calls(answer: str, tool_names: list[str]):
return matches, start_pos
def _extract_gemma4_balanced(text, start):
"""Extract balanced braces from Gemma 4 format, using <|"|> as string delimiters."""
if start >= len(text) or text[start] != '{':
return None
depth = 0
in_string = False
quote_token = '<|"|>'
quote_len = len(quote_token)
i = start
while i < len(text):
if text[i:i + quote_len] == quote_token:
in_string = not in_string
i += quote_len
continue
if in_string:
i += 1
continue
c = text[i]
if c == '{':
depth += 1
elif c == '}':
depth -= 1
if depth == 0:
return text[start:i + 1]
i += 1
return None
def _parse_gemma4_tool_calls(answer: str, tool_names: list[str]):
"""Parse Gemma 4-style tool calls.
Format:
<|tool_call>call:func_name{key:<|"|>value<|"|>,...}<tool_call|>
Values use <|"|> tokens instead of standard JSON quotes, and keys are
bare identifiers.
"""
matches = []
start_pos = None
for m in re.finditer(r'<\|tool_call>call:([^\s{]+)\s*', answer):
func_name = m.group(1).strip()
if func_name not in tool_names:
continue
brace_start = m.end()
if brace_start >= len(answer) or answer[brace_start] != '{':
continue
content = _extract_gemma4_balanced(answer, brace_start)
if content is None:
continue
# Convert to JSON: split on <|"|> tokens so that key quoting
# only applies outside string values (even-indexed parts),
# then rejoin with real quotes.
parts = content.split('<|"|>')
for idx in range(0, len(parts), 2):
parts[idx] = re.sub(r'(^|[{,\[])\s*(\w+)\s*:', r'\1"\2":', parts[idx])
json_str = '"'.join(parts)
try:
arguments = json.loads(json_str)
if start_pos is None:
start_pos = m.start()
matches.append(_make_tool_call(func_name, arguments))
except (json.JSONDecodeError, ValueError):
pass
return matches, start_pos
def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]):
"""Parse pythonic-style tool calls used by Llama 4 and similar models.
@ -472,6 +545,11 @@ TOOL_CALL_FORMATS = [
'parser': _parse_channel_tool_calls,
'markers': ['to=functions.', '<|channel|>commentary'],
},
{
'template_hints': ['<|tool_call>call:'],
'parser': _parse_gemma4_tool_calls,
'markers': ['<|tool_call>call:'],
},
{
'template_hints': ['minimax:tool_call'],
'parser': _parse_minimax_tool_calls,
@ -504,6 +582,7 @@ ALL_PARSERS = [
_parse_deep_seek_tool_calls,
_parse_kimi_tool_calls,
_parse_channel_tool_calls,
_parse_gemma4_tool_calls,
_parse_minimax_tool_calls,
_parse_glm_tool_calls,
_parse_xml_param_tool_calls,