From e0e20ab9e7f0dfc529898b80c1a6c44561e85658 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 19 Mar 2026 08:02:23 -0700 Subject: [PATCH] Minor cleanup across multiple modules --- extensions/openai/completions.py | 4 +- modules/llama_cpp_server.py | 5 +-- modules/shared.py | 6 +-- modules/tool_parsing.py | 76 ++++++-------------------------- modules/training.py | 12 ++--- modules/ui.py | 7 ++- 6 files changed, 28 insertions(+), 82 deletions(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index fc17a19a..d0cd9802 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -263,7 +263,7 @@ def convert_history(history): seen_non_system = True meta = {} tool_calls = entry.get("tool_calls") - if tool_calls and isinstance(tool_calls, list) and len(tool_calls) > 0: + if tool_calls and isinstance(tool_calls, list): meta["tool_calls"] = tool_calls if content.strip() == "": content = "" # keep empty content, don't skip @@ -315,7 +315,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p raise InvalidRequestError(message="messages is required", param='messages') tools = None - if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0: + if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and body['tools']: tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails tool_choice = body.get('tool_choice', None) diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 6dd36b2a..2ae01ddc 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -500,9 +500,8 @@ class LlamaServer: health_url = f"http://127.0.0.1:{self.port}/health" while True: # Check if process is still alive - if self.process.poll() is not None: - # Process has terminated - exit_code = self.process.poll() + exit_code = 
self.process.poll() + if exit_code is not None: raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}") try: diff --git a/modules/shared.py b/modules/shared.py index 2382e714..37bc5876 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -453,15 +453,12 @@ def load_user_config(): ''' Loads custom model-specific settings ''' + user_config = {} if Path(f'{args.model_dir}/config-user.yaml').exists(): file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip() if file_content: user_config = yaml.safe_load(file_content) - else: - user_config = {} - else: - user_config = {} return user_config diff --git a/modules/tool_parsing.py b/modules/tool_parsing.py index 0454e901..7a7ed5d8 100644 --- a/modules/tool_parsing.py +++ b/modules/tool_parsing.py @@ -3,6 +3,10 @@ import random import re +def _make_tool_call(name, arguments): + return {"type": "function", "function": {"name": name, "arguments": arguments}} + + def get_tool_call_id() -> str: letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" b = [random.choice(letter_bytes) for _ in range(8)] @@ -149,13 +153,7 @@ def _parse_channel_tool_calls(answer: str, tool_names: list[str]): if start_pos is None: prefix = answer.rfind('<|start|>assistant', 0, m.start()) start_pos = prefix if prefix != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) except json.JSONDecodeError: pass if matches: @@ -185,13 +183,7 @@ def _parse_mistral_token_tool_calls(answer: str, tool_names: list[str]): arguments = json.loads(json_str) if start_pos is None: start_pos = m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) except json.JSONDecodeError: pass return matches, start_pos @@ -226,13 +218,7 @@ def 
_parse_bare_name_tool_calls(answer: str, tool_names: list[str]): arguments = json.loads(json_str) if start_pos is None: start_pos = match.start() - matches.append({ - "type": "function", - "function": { - "name": name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(name, arguments)) except json.JSONDecodeError: pass return matches, start_pos @@ -269,13 +255,7 @@ def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]): arguments[param_name] = param_value if start_pos is None: start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) return matches, start_pos @@ -305,13 +285,7 @@ def _parse_kimi_tool_calls(answer: str, tool_names: list[str]): # Check for section begin marker before the call marker section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start()) start_pos = section if section != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) except json.JSONDecodeError: pass return matches, start_pos @@ -348,13 +322,7 @@ def _parse_minimax_tool_calls(answer: str, tool_names: list[str]): arguments[param_name] = param_value if start_pos is None: start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) return matches, start_pos @@ -382,13 +350,7 @@ def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]): # Check for section begin marker before the call marker section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start()) start_pos = section if section != -1 else m.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + 
matches.append(_make_tool_call(func_name, arguments)) except json.JSONDecodeError: pass return matches, start_pos @@ -428,13 +390,7 @@ def _parse_glm_tool_calls(answer: str, tool_names: list[str]): arguments[k] = v if start_pos is None: start_pos = tc_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) return matches, start_pos @@ -486,13 +442,7 @@ def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): if start_pos is None: start_pos = bracket_match.start() - matches.append({ - "type": "function", - "function": { - "name": func_name, - "arguments": arguments - } - }) + matches.append(_make_tool_call(func_name, arguments)) return matches, start_pos diff --git a/modules/training.py b/modules/training.py index a13a2864..145353c6 100644 --- a/modules/training.py +++ b/modules/training.py @@ -732,11 +732,13 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: if lora_all_param > 0: print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") - train_log.update({"base_model_name": shared.model_name}) - train_log.update({"base_model_class": shared.model.__class__.__name__}) - train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)}) - train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)}) - train_log.update({"projections": projections_string}) + train_log.update({ + "base_model_name": shared.model_name, + "base_model_class": shared.model.__class__.__name__, + "base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False), + "base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False), + "projections": projections_string, + }) if stop_at_loss > 0: print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: 
{stop_at_loss})\033[0;37;0m") diff --git a/modules/ui.py b/modules/ui.py index bbb22266..20bc8373 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -299,7 +299,7 @@ def apply_interface_values(state, use_persistent=False): elements = list_interface_input_elements() - if len(state) == 0: + if not state: return [gr.update() for k in elements] # Dummy, do nothing else: return [state[k] if k in state else gr.update() for k in elements] @@ -307,9 +307,8 @@ def apply_interface_values(state, use_persistent=False): def save_settings(state, preset, extensions_list, show_controls, theme_state, manual_save=False): output = copy.deepcopy(shared.settings) - exclude = [] for k in state: - if k in shared.settings and k not in exclude: + if k in shared.settings: output[k] = state[k] if preset: @@ -323,7 +322,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma output['custom_stopping_strings'] = output.get('custom_stopping_strings') or '' output['custom_token_bans'] = output.get('custom_token_bans') or '' output['show_controls'] = show_controls - output['dark_theme'] = True if theme_state == 'dark' else False + output['dark_theme'] = theme_state == 'dark' output.pop('instruction_template_str') output.pop('truncation_length')