diff --git a/modules/chat.py b/modules/chat.py index 10969446..b0be2bc2 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -298,6 +298,23 @@ def generate_chat_prompt(user_input, state, **kwargs): if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant': messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls']) + # Expand tool_sequence from metadata (inserted AFTER assistant so that + # the final order is: user → tool_calls → tool_results → final_answer) + meta_key = f"assistant_{row_idx}" + tool_seq = metadata.get(meta_key, {}).get('tool_sequence', []) + if tool_seq: + for item in reversed(tool_seq): + if 'tool_calls' in item: + messages.insert(insert_pos, { + "role": "assistant", "content": "", + "tool_calls": _deserialize_tool_call_arguments(item['tool_calls']) + }) + elif item.get('role') == 'tool': + messages.insert(insert_pos, { + "role": "tool", "content": item['content'], + "tool_call_id": item.get('tool_call_id', '') + }) + if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']: # Check for user message attachments in metadata user_key = f"user_{row_idx}" @@ -367,6 +384,22 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "user", "content": user_input}) + # Expand tool_sequence for the current entry (excluded from the + # history loop during regenerate — needed so the model sees prior + # tool calls and results when re-generating the final answer). 
+ current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', []) + for item in current_tool_seq: + if 'tool_calls' in item: + messages.append({ + "role": "assistant", "content": "", + "tool_calls": _deserialize_tool_call_arguments(item['tool_calls']) + }) + elif item.get('role') == 'tool': + messages.append({ + "role": "tool", "content": item['content'], + "tool_call_id": item.get('tool_call_id', '') + }) + if impersonate and state['mode'] != 'chat-instruct': messages.append({"role": "user", "content": "fake user message replace me"}) @@ -886,7 +919,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess } else: text, visible_text = output['internal'][-1][0], output['visible'][-1][0] - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 # Store the old response as a version before regenerating @@ -984,7 +1017,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] # Keep version metadata in sync during streaming (for regeneration) - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 key = f"assistant_{row_idx}" current_idx = output['metadata'][key]['current_version_index'] @@ -1012,7 +1045,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True) # Final sync for version metadata (in case streaming was disabled) - if regenerate: + if regenerate and not state.get('_tool_turn'): row_idx = len(output['internal']) - 1 key = f"assistant_{row_idx}" current_idx = output['metadata'][key]['current_version_index'] @@ -1066,12 +1099,24 @@ def character_is_loaded(state, raise_exception=False): def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): ''' - Same as above but 
returns HTML for the UI + Same as above but returns HTML for the UI. + When tools are selected, wraps generation in a loop that detects + tool calls, executes them, and re-generates until the model stops. + All tool output is consolidated into a single visible chat bubble + using metadata['assistant_N']['tool_sequence']. ''' if not character_is_loaded(state): return + # On regenerate, clear old tool_sequence metadata so it gets rebuilt + if regenerate: + history = state['history'] + meta = history.get('metadata', {}) + row_idx = len(history['internal']) - 1 + if row_idx >= 0: + meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None) + if state['start_with'] != '' and not _continue: if regenerate: text, state['history'] = remove_last_message(state['history']) @@ -1081,23 +1126,139 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): send_dummy_message(text, state) send_dummy_reply(state['start_with'], state) - history = state['history'] + # Load tools if any are selected + selected = state.get('selected_tools', []) + if selected: + from modules.tool_use import load_tools, execute_tool, generate_tool_call_id + try: + from extensions.openai.utils import parseToolCall + except ImportError: + logger.warning('Tool calling requires the openai extension for parseToolCall. 
Disabling tools.') + selected = [] + + if selected: + tool_defs, tool_executors = load_tools(selected) + state['tools'] = tool_defs + tool_func_names = [t['function']['name'] for t in tool_defs] + else: + tool_func_names = None + + visible_prefix = [] # Accumulated tool call summaries + results last_save_time = time.monotonic() save_interval = 8 - for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)): - yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history - if i == 0: - time.sleep(0.125) # We need this to make sure the first update goes through + max_tool_turns = 10 - current_time = time.monotonic() - # Save on first iteration or if save_interval seconds have passed - if i == 0 or (current_time - last_save_time) >= save_interval: - save_history(history, state['unique_id'], state['character_menu'], state['mode']) - last_save_time = current_time + for _tool_turn in range(max_tool_turns): + history = state['history'] + + # Turn 0: use original flags; turns 1+: regenerate into the same entry + if _tool_turn > 0: + state['_tool_turn'] = True + + regen = regenerate if _tool_turn == 0 else True + cont = _continue if _tool_turn == 0 else False + cur_text = text if _tool_turn == 0 else '' + + for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)): + # Prepend accumulated tool output to visible reply + if visible_prefix: + history['visible'][-1][1] = '\n\n'.join(visible_prefix + [history['visible'][-1][1]]) + + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history + + if i == 0: + time.sleep(0.125) + + current_time = time.monotonic() + if i == 0 or (current_time - last_save_time) >= save_interval: + save_history(history,
state['unique_id'], state['character_menu'], state['mode']) + last_save_time = current_time + + # Early stop on tool call detection + if tool_func_names and parseToolCall(history['internal'][-1][1], tool_func_names): + break + + save_history(history, state['unique_id'], state['character_menu'], state['mode']) + + # Check for tool calls + if not tool_func_names or shared.stop_everything: + break + + answer = history['internal'][-1][1] + parsed_calls = parseToolCall(answer, tool_func_names) if answer else None + + if not parsed_calls: + break # No tool calls — done + + # --- Process tool calls --- + row_idx = len(history['internal']) - 1 + meta = history.get('metadata', {}) + seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', []) + + # Serialize tool calls + serialized = [] + for tc in parsed_calls: + tc['id'] = generate_tool_call_id() + args = tc['function'].get('arguments', {}) + serialized.append({ + 'id': tc['id'], + 'type': 'function', + 'function': { + 'name': tc['function']['name'], + 'arguments': json.dumps(args) if isinstance(args, dict) else args + } + }) + + seq.append({'tool_calls': serialized}) + + # Clear internal (raw tool markup) + history['internal'][-1][1] = '' + + # Add call summary to visible prefix + call_summary = ', '.join(f'{tc["function"]["name"]}(...)' for tc in parsed_calls) + visible_prefix.append('Calling: ' + call_summary) + + # Execute tools, store results + for tc in parsed_calls: + fn_name = tc['function']['name'] + fn_args = tc['function'].get('arguments', {}) + result = execute_tool(fn_name, fn_args, tool_executors) + + seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']}) + try: + pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False) + except (json.JSONDecodeError, TypeError): + pretty_result = result + + visible_prefix.append(f'**{fn_name}**\n```json\n{pretty_result}\n```') + + # Show tool results + history['visible'][-1][1] = '\n\n'.join(visible_prefix) + yield 
chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history + save_history(history, state['unique_id'], state['character_menu'], state['mode']) + + state['history'] = history + + state.pop('_tool_turn', None) + state['history'] = history + + # Sync version metadata so swipes show the full visible (with tool prefix) + if visible_prefix and history.get('metadata'): + row_idx = len(history['internal']) - 1 + key = f"assistant_{row_idx}" + meta_entry = history['metadata'].get(key, {}) + if 'versions' in meta_entry and 'current_version_index' in meta_entry: + current_idx = meta_entry['current_version_index'] + if current_idx < len(meta_entry['versions']): + meta_entry['versions'][current_idx].update({ + 'content': history['internal'][row_idx][1], + 'visible_content': history['visible'][row_idx][1] + }) save_history(history, state['unique_id'], state['character_menu'], state['mode']) + def remove_last_message(history): if 'metadata' not in history: history['metadata'] = {} diff --git a/modules/shared.py b/modules/shared.py index dbd805a1..395ca83c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -263,6 +263,7 @@ settings = { 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". 
Reply directly, without starting the reply with the character name.\n\n<|prompt|>', 'enable_web_search': False, 'web_search_pages': 3, + 'selected_tools': [], 'prompt-notebook': '', 'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None, 'max_new_tokens': 512, diff --git a/modules/tool_use.py b/modules/tool_use.py new file mode 100644 index 00000000..cb1e140d --- /dev/null +++ b/modules/tool_use.py @@ -0,0 +1,70 @@ +import importlib.util +import json +import random + +from modules import shared +from modules.logging_colors import logger + + +def get_available_tools(): + """Return sorted list of tool script names from user_data/tools/*.py.""" + tools_dir = shared.user_data_dir / 'tools' + tools_dir.mkdir(parents=True, exist_ok=True) + return sorted(p.stem for p in tools_dir.glob('*.py')) + + +def load_tools(selected_names): + """ + Import selected tool scripts and return their definitions and executors. + Returns (tool_defs, executors) where: + - tool_defs: list of OpenAI-format tool dicts + - executors: dict mapping function_name -> execute callable + """ + tool_defs = [] + executors = {} + for name in selected_names: + path = shared.user_data_dir / 'tools' / f'{name}.py' + if not path.exists(): + continue + + try: + spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except Exception: + logger.exception(f'Failed to load tool script "{name}"') + continue + + tool_def = getattr(module, 'tool', None) + execute_fn = getattr(module, 'execute', None) + if tool_def is None or execute_fn is None: + logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.') + continue + + func_name = tool_def.get('function', {}).get('name', name) + tool_defs.append(tool_def) + executors[func_name] = execute_fn + + return tool_defs, executors + + +def generate_tool_call_id(): + """Generate a unique tool call ID (e.g. 
'call_a1b2c3d4').""" + chars = "abcdefghijklmnopqrstuvwxyz0123456789" + return "call_" + "".join(random.choice(chars) for _ in range(8)) + + +def execute_tool(func_name, arguments, executors): + """Execute a tool by function name. Returns result as a JSON string.""" + fn = executors.get(func_name) + if fn is None: + return json.dumps({"error": f"Unknown tool: {func_name}"}) + + try: + if isinstance(arguments, str): + arguments = json.loads(arguments) + result = fn(arguments) + return json.dumps(result) if not isinstance(result, str) else result + except Exception as e: + logger.exception(f'Tool "{func_name}" execution failed') + return json.dumps({"error": str(e)}) diff --git a/modules/ui.py b/modules/ui.py index 2ab30563..3f39a1a4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -199,6 +199,7 @@ def list_interface_input_elements(): 'unique_id', 'textbox', 'start_with', + 'selected_tools', 'mode', 'chat_style', 'chat-instruct_command', @@ -424,6 +425,7 @@ def setup_auto_save(): 'user_bio', 'custom_system_message', 'chat_template_str', + 'selected_tools', # Parameters tab (ui_parameters.py) - Generation parameters 'preset_menu', diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 74da0a40..9c7424e7 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -91,6 +91,11 @@ def create_ui(): gr.HTML("") + from modules.tool_use import get_available_tools + shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=[], label='Tools', info='Functions the model can call during generation.') + + gr.HTML("") + with gr.Row(): shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode')