Support installing user extensions in user_data/extensions/

This commit is contained in:
oobabooga 2025-07-06 17:29:29 -07:00
parent 959d4ddb91
commit e6bc7742fb
3 changed files with 48 additions and 23 deletions

View file

@ -2,10 +2,10 @@ import importlib
import traceback import traceback
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from pathlib import Path
import gradio as gr import gradio as gr
import extensions
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
@ -28,36 +28,51 @@ def apply_settings(extension, name):
def load_extensions(): def load_extensions():
global state, setup_called global state, setup_called
state = {} state = {}
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
if name in available_extensions: if name not in available_extensions:
if name != 'api': continue
logger.info(f'Loading the extension "{name}"')
try:
try:
extension = importlib.import_module(f"extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n")
raise
# Only run setup() and apply settings from settings.yaml once if name != 'api':
if extension not in setup_called: logger.info(f'Loading the extension "{name}"')
apply_settings(extension, name)
if hasattr(extension, "setup"):
extension.setup()
setup_called.add(extension) try:
# Prefer user extension, fall back to system extension
user_script_path = Path(f'user_data/extensions/{name}/script.py')
if user_script_path.exists():
extension = importlib.import_module(f"user_data.extensions.{name}.script")
else:
extension = importlib.import_module(f"extensions.{name}.script")
state[name] = [True, i] if extension not in setup_called:
except: apply_settings(extension, name)
logger.error(f'Failed to load the extension "{name}".') if hasattr(extension, "setup"):
traceback.print_exc() extension.setup()
setup_called.add(extension)
state[name] = [True, i, extension] # Store extension object
except ModuleNotFoundError:
extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
logger.error(
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
f"* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n"
f"* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\n"
f"Linux / Mac:\n\npip install -r {extension_location}/requirements.txt --upgrade\n\n"
f"Windows:\n\npip install -r {extension_location}\\requirements.txt --upgrade\n"
)
raise
except Exception:
logger.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
for name in sorted(state, key=lambda x: state[x][1]): for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0]: if state[name][0]:
yield getattr(extensions, name).script, name yield state[name][2], name # Use stored extension object
# Extension functions that map string -> string # Extension functions that map string -> string

View file

@ -183,8 +183,18 @@ def get_available_instruction_templates():
def get_available_extensions(): def get_available_extensions():
extensions = sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) # User extensions (higher priority)
return extensions user_extensions = []
user_ext_path = Path('user_data/extensions')
if user_ext_path.exists():
user_exts = map(lambda x: x.parts[2], user_ext_path.glob('*/script.py'))
user_extensions = sorted(set(user_exts), key=natural_keys)
# System extensions (excluding those overridden by user extensions)
system_exts = map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))
system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys)
return user_extensions + system_extensions
def get_available_loras(): def get_available_loras():