Support installing user extensions in user_data/extensions/

This commit is contained in:
oobabooga 2025-07-06 17:29:29 -07:00
parent 959d4ddb91
commit e6bc7742fb
3 changed files with 48 additions and 23 deletions

View file

@ -2,10 +2,10 @@ import importlib
import traceback
from functools import partial
from inspect import signature
from pathlib import Path
import gradio as gr
import extensions
import modules.shared as shared
from modules.logging_colors import logger
@ -28,36 +28,51 @@ def apply_settings(extension, name):
def load_extensions():
global state, setup_called
state = {}
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
if name != 'api':
logger.info(f'Loading the extension "{name}"')
try:
try:
extension = importlib.import_module(f"extensions.{name}.script")
except ModuleNotFoundError:
logger.error(f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\nLinux / Mac:\n\npip install -r extensions/{name}/requirements.txt --upgrade\n\nWindows:\n\npip install -r extensions\\{name}\\requirements.txt --upgrade\n")
raise
if name not in available_extensions:
continue
# Only run setup() and apply settings from settings.yaml once
if extension not in setup_called:
apply_settings(extension, name)
if hasattr(extension, "setup"):
extension.setup()
if name != 'api':
logger.info(f'Loading the extension "{name}"')
setup_called.add(extension)
try:
# Prefer user extension, fall back to system extension
user_script_path = Path(f'user_data/extensions/{name}/script.py')
if user_script_path.exists():
extension = importlib.import_module(f"user_data.extensions.{name}.script")
else:
extension = importlib.import_module(f"extensions.{name}.script")
state[name] = [True, i]
except:
logger.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
if extension not in setup_called:
apply_settings(extension, name)
if hasattr(extension, "setup"):
extension.setup()
setup_called.add(extension)
state[name] = [True, i, extension] # Store extension object
except ModuleNotFoundError:
extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
logger.error(
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
f"* To install requirements for all available extensions, launch the\n update_wizard script for your OS and choose the B option.\n\n"
f"* To install the requirements for this extension alone, launch the\n cmd script for your OS and paste the following command in the\n terminal window that appears:\n\n"
f"Linux / Mac:\n\npip install -r {extension_location}/requirements.txt --upgrade\n\n"
f"Windows:\n\npip install -r {extension_location}\\requirements.txt --upgrade\n"
)
raise
except Exception:
logger.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line
def iterator():
for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0]:
yield getattr(extensions, name).script, name
yield state[name][2], name # Use stored extension object
# Extension functions that map string -> string