From e6bc7742fb1e28fde10ee939f870eeb75545ec56 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 6 Jul 2025 17:29:29 -0700 Subject: [PATCH] Support installing user extensions in user_data/extensions/ --- modules/extensions.py | 57 ++++++++++++------- modules/utils.py | 14 ++++- .../extensions/place-your-extensions-here.txt | 0 3 files changed, 48 insertions(+), 23 deletions(-) create mode 100644 user_data/extensions/place-your-extensions-here.txt diff --git a/modules/extensions.py b/modules/extensions.py index 6729b996..be9bc38c 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -2,10 +2,10 @@ import importlib import traceback from functools import partial from inspect import signature +from pathlib import Path import gradio as gr -import extensions import modules.shared as shared from modules.logging_colors import logger @@ -28,36 +28,51 @@ def apply_settings(extension, name): def load_extensions(): global state, setup_called state = {} + for i, name in enumerate(shared.args.extensions): - if name in available_extensions: - if name != 'api': - logger.info(f'Loading the extension "{name}"') - try: - try: - extension = importlib.import_module(f"extensions.{name}.script") - except ModuleNotFoundError: - logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n") - raise + if name not in available_extensions: + continue - # Only run setup() and apply settings from settings.yaml once - if extension not in setup_called: - apply_settings(extension, name) - if hasattr(extension, "setup"): - extension.setup() + if name != 'api': + logger.info(f'Loading the extension "{name}"') - setup_called.add(extension) + try: + # Prefer user extension, fall back to system extension + user_script_path = Path(f'user_data/extensions/{name}/script.py') + if user_script_path.exists(): + extension = importlib.import_module(f"user_data.extensions.{name}.script") + else: + extension = importlib.import_module(f"extensions.{name}.script") - state[name] = [True, i] - except: - logger.error(f'Failed to load the extension "{name}".') - traceback.print_exc() + if extension not in setup_called: + apply_settings(extension, name) + if hasattr(extension, "setup"): + extension.setup() + setup_called.add(extension) + + state[name] = [True, i, extension] # Store extension object + + except ModuleNotFoundError: + extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name + logger.error( + f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n" + f"* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n" + f"* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\n" + f"Linux / Mac:\n\npip install -r {extension_location}/requirements.txt --upgrade\n\n" + f"Windows:\n\npip install -r {extension_location}\\requirements.txt --upgrade\n" + ) + raise + + except Exception: + logger.error(f'Failed to load the extension "{name}".') + traceback.print_exc() # This iterator returns the extensions in the order specified in the command-line def iterator(): for name in sorted(state, key=lambda x: state[x][1]): if state[name][0]: - yield getattr(extensions, name).script, name + yield state[name][2], name # Use stored extension object # Extension functions that map string -> string diff --git a/modules/utils.py b/modules/utils.py index c285d401..117ad590 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -183,8 +183,18 @@ def get_available_instruction_templates(): def get_available_extensions(): - extensions = sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) - return extensions + # User extensions (higher priority) + user_extensions = [] + user_ext_path = Path('user_data/extensions') + if user_ext_path.exists(): + user_exts = map(lambda x: x.parts[2], user_ext_path.glob('*/script.py')) + user_extensions = sorted(set(user_exts), key=natural_keys) + + # System extensions (excluding those overridden by user extensions) + system_exts = map(lambda x: x.parts[1], Path('extensions').glob('*/script.py')) + system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys) + + return user_extensions + system_extensions def get_available_loras(): diff --git a/user_data/extensions/place-your-extensions-here.txt b/user_data/extensions/place-your-extensions-here.txt new file mode 100644 index 00000000..e69de29b