From d78abe480b06a977388798099c068ed3c1fed812 Mon Sep 17 00:00:00 2001 From: Googolplexed <65880807+Googolplexed0@users.noreply.github.com> Date: Fri, 18 Apr 2025 01:53:59 -0400 Subject: [PATCH] Allow for model subfolder organization for GGUF files (#6686) --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- modules/models_settings.py | 8 +++--- modules/utils.py | 58 ++++++++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/modules/models_settings.py b/modules/models_settings.py index b83544d4..693c4dde 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -29,7 +29,7 @@ def get_model_metadata(model): # Get settings from models/config.yaml and models/config-user.yaml settings = shared.model_config for pat in settings: - if re.match(pat.lower(), model.lower()): + if re.match(pat.lower(), Path(model).name.lower()): for k in settings[pat]: model_settings[k] = settings[pat][k] @@ -148,7 +148,7 @@ def get_model_metadata(model): # Apply user settings from models/config-user.yaml settings = shared.user_config for pat in settings: - if re.match(pat.lower(), model.lower()): + if re.match(pat.lower(), Path(model).name.lower()): for k in settings[pat]: model_settings[k] = settings[pat][k] @@ -254,7 +254,7 @@ def save_model_settings(model, state): return user_config = shared.load_user_config() - model_regex = model + '$' # For exact matches + model_regex = Path(model).name + '$' # For exact matches if model_regex not in user_config: user_config[model_regex] = {} @@ -281,7 +281,7 @@ def save_instruction_template(model, template): return user_config = shared.load_user_config() - model_regex = model + '$' # For exact matches + model_regex = Path(model).name + '$' # For exact matches if model_regex not in user_config: user_config[model_regex] = {} diff --git a/modules/utils.py b/modules/utils.py index f4333031..81cd085a 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -73,21 +73,61 @@ def natural_keys(text): def get_available_models(): - model_list = [] - for item in list(Path(f'{shared.args.model_dir}/').glob('*')): - if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml', '.py')) and 'llama-tokenizer' not in item.name: - model_list.append(item.name) + # Get all GGUF files + gguf_files = get_available_ggufs() - return ['None'] + sorted(model_list, key=natural_keys) + model_dir = Path(shared.args.model_dir) + + # Find top-level directories containing GGUF files + dirs_with_gguf = set() + for gguf_path in gguf_files: + path = Path(gguf_path) + if path.parts: # If in a subdirectory + dirs_with_gguf.add(path.parts[0]) # Add top-level directory + + # Find directories with safetensors files directly under them + dirs_with_safetensors = set() + for item in os.listdir(model_dir): + item_path = model_dir / item + if item_path.is_dir(): + # Check if there are safetensors files directly under this directory + if any(file.lower().endswith('.safetensors') for file in os.listdir(item_path) if (item_path / file).is_file()): + dirs_with_safetensors.add(item) + + # Find valid model directories + model_dirs = [] + + for item in os.listdir(model_dir): + item_path = model_dir / item + + # Skip if not a directory + if not item_path.is_dir(): + continue + + # Include directory if it either: + # 1. Doesn't contain GGUF files, OR + # 2. Contains both GGUF and safetensors files + if item not in dirs_with_gguf or item in dirs_with_safetensors: + model_dirs.append(item) + + model_dirs = sorted(model_dirs, key=natural_keys) + + # Combine all models + return ['None'] + gguf_files + model_dirs def get_available_ggufs(): model_list = [] - for item in Path(f'{shared.args.model_dir}/').glob('*'): - if item.is_file() and item.name.lower().endswith(".gguf"): - model_list.append(item.name) + model_dir = Path(shared.args.model_dir) - return ['None'] + sorted(model_list, key=natural_keys) + for dirpath, _, files in os.walk(model_dir, followlinks=True): + for file in files: + if file.lower().endswith(".gguf"): + model_path = Path(dirpath) / file + rel_path = model_path.relative_to(model_dir) + model_list.append(str(rel_path)) + + return sorted(model_list, key=natural_keys) def get_available_presets():