mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
UI: Fix Continue while in a tool-calling loop, remove the upper limit on number of tool calls
This commit is contained in:
parent
04213dff14
commit
85ec85e569
|
|
@ -161,6 +161,49 @@ def _deserialize_tool_call_arguments(tool_calls):
|
|||
return result
|
||||
|
||||
|
||||
def _expand_tool_sequence(tool_seq):
    """Expand a tool_sequence list into API messages.

    Returns a list of dicts (role: assistant with tool_calls, or role: tool).
    If any tool_call IDs are missing a matching tool result, a synthetic
    empty result is inserted so the prompt is never malformed.

    Synthetic results are emitted right before the next assistant tool-call
    message (or at the end of the list for the last one), so a placeholder
    always stays grouped with the assistant message that issued the call
    instead of trailing behind a later turn.

    Args:
        tool_seq: Iterable of dicts. Items containing a 'tool_calls' key
            become assistant messages; items whose 'role' is 'tool' become
            tool messages; any other item is ignored.

    Returns:
        A new list of message dicts in prompt order.

    Raises:
        KeyError: If a tool item has no 'content' key.
    """
    # Materialize once: the sequence is walked twice below.
    tool_seq = list(tool_seq)

    # Collect every answered ID up front so that a real result stored later
    # in the sequence is never duplicated by a synthetic placeholder.
    answered_ids = {
        item.get('tool_call_id', '')
        for item in tool_seq
        if 'tool_calls' not in item and item.get('role') == 'tool'
    }

    messages = []
    pending_ids = []  # IDs issued since the last flush, in call order

    def _flush_orphans():
        """Append an empty tool result for each pending ID with no result."""
        for tc_id in pending_ids:
            if tc_id not in answered_ids:
                messages.append({
                    "role": "tool",
                    "content": "",
                    "tool_call_id": tc_id
                })
        pending_ids.clear()

    for item in tool_seq:
        if 'tool_calls' in item:
            # Close out the previous assistant turn before starting a new
            # one, so its orphaned calls are answered in place.
            _flush_orphans()
            deserialized = _deserialize_tool_call_arguments(item['tool_calls'])
            messages.append({
                "role": "assistant",
                "content": "",
                "tool_calls": deserialized
            })
            for tc in item['tool_calls']:
                tc_id = tc.get('id', '')
                if tc_id:  # calls without an ID cannot be matched to a result
                    pending_ids.append(tc_id)
        elif item.get('role') == 'tool':
            messages.append({
                "role": "tool",
                "content": item['content'],
                "tool_call_id": item.get('tool_call_id', '')
            })

    # Fill in synthetic results for orphaned calls of the final turn
    _flush_orphans()

    return messages
|
||||
|
||||
|
||||
def generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs.get('impersonate', False)
|
||||
_continue = kwargs.get('_continue', False)
|
||||
|
|
@ -312,17 +355,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
meta_key = f"assistant_{row_idx}"
|
||||
tool_seq = metadata.get(meta_key, {}).get('tool_sequence', [])
|
||||
if tool_seq:
|
||||
for item in reversed(tool_seq):
|
||||
if 'tool_calls' in item:
|
||||
messages.insert(insert_pos, {
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.insert(insert_pos, {
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
for msg in reversed(_expand_tool_sequence(tool_seq)):
|
||||
messages.insert(insert_pos, msg)
|
||||
|
||||
if entry_meta.get('role') == 'system':
|
||||
if user_msg:
|
||||
|
|
@ -400,17 +434,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
# history loop during regenerate — needed so the model sees prior
|
||||
# tool calls and results when re-generating the final answer).
|
||||
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', [])
|
||||
for item in current_tool_seq:
|
||||
if 'tool_calls' in item:
|
||||
messages.append({
|
||||
"role": "assistant", "content": "",
|
||||
"tool_calls": _deserialize_tool_call_arguments(item['tool_calls'])
|
||||
})
|
||||
elif item.get('role') == 'tool':
|
||||
messages.append({
|
||||
"role": "tool", "content": item['content'],
|
||||
"tool_call_id": item.get('tool_call_id', '')
|
||||
})
|
||||
messages.extend(_expand_tool_sequence(current_tool_seq))
|
||||
|
||||
if impersonate and state['mode'] != 'chat-instruct':
|
||||
messages.append({"role": "user", "content": "fake user message replace me"})
|
||||
|
|
@ -1181,9 +1205,8 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
visible_prefix = [] # Accumulated tool call summaries + results
|
||||
last_save_time = time.monotonic()
|
||||
save_interval = 8
|
||||
max_tool_turns = 10
|
||||
|
||||
for _tool_turn in range(max_tool_turns):
|
||||
_tool_turn = 0
|
||||
while True:
|
||||
history = state['history']
|
||||
|
||||
# Turn 0: use original flags; turns 1+: regenerate into the same entry.
|
||||
|
|
@ -1324,6 +1347,16 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
|
||||
# Execute tools, store results, and replace placeholders with real results
|
||||
for i, tc in enumerate(parsed_calls):
|
||||
# Check for stop request before each tool execution
|
||||
if shared.stop_everything:
|
||||
for j in range(i, len(parsed_calls)):
|
||||
seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']})
|
||||
pending_placeholders[j] = f'<tool_call>{tc_headers[j]}\nCancelled\n</tool_call>'
|
||||
|
||||
history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
|
||||
yield _render(), history
|
||||
break
|
||||
|
||||
fn_name = tc['function']['name']
|
||||
fn_args = tc['function'].get('arguments', {})
|
||||
result = execute_tool(fn_name, fn_args, tool_executors)
|
||||
|
|
@ -1345,6 +1378,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
|||
save_history(history, state['unique_id'], state['character_menu'], state['mode'])
|
||||
|
||||
state['history'] = history
|
||||
_tool_turn += 1
|
||||
|
||||
state.pop('_tool_turn', None)
|
||||
state['history'] = history
|
||||
|
|
|
|||
Loading…
Reference in a new issue