UI: Fix Continue while in a tool-calling loop, remove the upper limit on number of tool calls

This commit is contained in:
oobabooga 2026-03-12 20:21:01 -07:00
parent 04213dff14
commit 85ec85e569

View file

@ -161,6 +161,49 @@ def _deserialize_tool_call_arguments(tool_calls):
return result
def _expand_tool_sequence(tool_seq):
"""Expand a tool_sequence list into API messages.
Returns a list of dicts (role: assistant with tool_calls, or role: tool).
If any tool_call IDs are missing a matching tool result, a synthetic
empty result is inserted so the prompt is never malformed.
"""
messages = []
expected_ids = []
seen_ids = set()
for item in tool_seq:
if 'tool_calls' in item:
deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
messages.append({
"role": "assistant",
"content": "",
"tool_calls": deserialized
})
for tc in item['tool_calls']:
tc_id = tc.get('id', '')
if tc_id:
expected_ids.append(tc_id)
elif item.get('role') == 'tool':
messages.append({
"role": "tool",
"content": item['content'],
"tool_call_id": item.get('tool_call_id', '')
})
seen_ids.add(item.get('tool_call_id', ''))
# Fill in synthetic results for any orphaned tool call IDs
for tc_id in expected_ids:
if tc_id not in seen_ids:
messages.append({
"role": "tool",
"content": "",
"tool_call_id": tc_id
})
return messages
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
@ -312,17 +355,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
meta_key = f"assistant_{row_idx}"
tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
if tool_seq:
for item in reversed(tool_seq):
if 'tool_calls' in item:
messages.insert(insert_pos, {
"role": "assistant", "content": "",
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
})
elif item.get('role') == 'tool':
messages.insert(insert_pos, {
"role": "tool", "content": item['content'],
"tool_call_id": item.get('tool_call_id', '')
})
for msg in reversed(_expand_tool_sequence(tool_seq)):
messages.insert(insert_pos, msg)
if entry_meta.get('role') == 'system':
if user_msg:
@ -400,17 +434,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# history loop during regenerate — needed so the model sees prior
# tool calls and results when re-generating the final answer).
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
for item in current_tool_seq:
if 'tool_calls' in item:
messages.append({
"role": "assistant", "content": "",
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
})
elif item.get('role') == 'tool':
messages.append({
"role": "tool", "content": item['content'],
"tool_call_id": item.get('tool_call_id', '')
})
messages.extend(_expand_tool_sequence(current_tool_seq))
if impersonate and state['mode'] != 'chat-instruct':
messages.append({"role": "user", "content": "fake user message replace me"})
@ -1181,9 +1205,8 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
visible_prefix = [] # Accumulated tool call summaries + results
last_save_time = time.monotonic()
save_interval = 8
max_tool_turns = 10
for _tool_turn in range(max_tool_turns):
_tool_turn = 0
while True:
history = state['history']
# Turn 0: use original flags; turns 2+: regenerate into the same entry.
@ -1324,6 +1347,16 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
# Execute tools, store results, and replace placeholders with real results
for i, tc in enumerate(parsed_calls):
# Check for stop request before each tool execution
if shared.stop_everything:
for j in range(i, len(parsed_calls)):
seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']})
pending_placeholders[j] = f'<tool_call>{tc_headers[j]}\nCancelled\n</tool_call>'
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
yield _render(), history
break
fn_name = tc['function']['name']
fn_args = tc['function'].get('arguments', {})
result = execute_tool(fn_name, fn_args, tool_executors)
@ -1345,6 +1378,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
state['history'] = history
_tool_turn += 1
state.pop('_tool_turn', None)
state['history'] = history