Support installing user extensions in user_data/extensions/

This commit is contained in:
oobabooga 2025-07-06 17:29:29 -07:00
parent 959d4ddb91
commit e6bc7742fb
3 changed files with 48 additions and 23 deletions

View file

@ -2,10 +2,10 @@ import importlib
import traceback
from functools import partial
from inspect import signature
from pathlib import Path
import gradio as gr
import extensions
import modules.shared as shared
from modules.logging_colors import logger
@ -28,36 +28,51 @@ def apply_settings(extension, name):
def load_extensions():
global state, setup_called
state = {}
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
if name != 'api':
logger.info(f'Loading the extension "{name}"')
try:
try:
extension = importlib.import_module(f"extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n")
raise
if name not in available_extensions:
continue
# Only run setup() and apply settings from settings.yaml once
if extension not in setup_called:
apply_settings(extension, name)
if hasattr(extension, "setup"):
extension.setup()
if name != 'api':
logger.info(f'Loading the extension "{name}"')
setup_called.add(extension)
try:
# Prefer user extension, fall back to system extension
user_script_path = Path(f'user_data/extensions/{name}/script.py')
if user_script_path.exists():
extension = importlib.import_module(f"user_data.extensions.{name}.script")
else:
extension = importlib.import_module(f"extensions.{name}.script")
state[name] = [True, i]
except:
logger.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
if extension not in setup_called:
apply_settings(extension, name)
if hasattr(extension, "setup"):
extension.setup()
setup_called.add(extension)
state[name] = [True, i, extension] # Store extension object
except ModuleNotFoundError:
extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
logger.error(
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
f"* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n"
f"* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\n"
f"Linux / Mac:\n\npip install -r {extension_location}/requirements.txt --upgrade\n\n"
f"Windows:\n\npip install -r {extension_location}\\requirements.txt --upgrade\n"
)
raise
except Exception:
logger.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line
def iterator():
for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0]:
yield getattr(extensions, name).script, name
yield state[name][2], name # Use stored extension object
# Extension functions that map string -> string

View file

@ -183,8 +183,18 @@ def get_available_instruction_templates():
def get_available_extensions():
extensions = sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
return extensions
# User extensions (higher priority)
user_extensions = []
user_ext_path = Path('user_data/extensions')
if user_ext_path.exists():
user_exts = map(lambda x: x.parts[2], user_ext_path.glob('*/script.py'))
user_extensions = sorted(set(user_exts), key=natural_keys)
# System extensions (excluding those overridden by user extensions)
system_exts = map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))
system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys)
return user_extensions + system_extensions
def get_available_loras():