From e2548f69a9e9a582f99b4deceeb5e5ed48f49418 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:26:21 -0800 Subject: [PATCH] Make user_data configurable: add --user-data-dir flag, auto-detect ../user_data If --user-data-dir is not set, auto-detect: use ../user_data when ./user_data doesn't exist, making it easy to share user data across portable builds by placing it one folder up. --- download-model.py | 14 +++++--- modules/chat.py | 60 +++++++++++++++++----------------- modules/evaluate.py | 8 ++--- modules/extensions.py | 14 ++++++-- modules/html_generator.py | 8 ++--- modules/llama_cpp_server.py | 2 +- modules/presets.py | 2 +- modules/prompts.py | 4 +-- modules/shared.py | 18 ++++++---- modules/training.py | 34 +++++++++---------- modules/ui.py | 6 ++-- modules/ui_chat.py | 2 +- modules/ui_default.py | 12 +++---- modules/ui_file_saving.py | 14 ++++---- modules/ui_image_generation.py | 4 +-- modules/ui_model_menu.py | 6 ++-- modules/ui_notebook.py | 14 ++++---- modules/ui_parameters.py | 2 +- modules/ui_session.py | 4 +-- modules/utils.py | 55 +++++++++++++++++++------------ server.py | 16 ++++----- 21 files changed, 166 insertions(+), 133 deletions(-) diff --git a/download-model.py b/download-model.py index 756d529f..7acb7f46 100644 --- a/download-model.py +++ b/download-model.py @@ -24,6 +24,8 @@ from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, RequestException, Timeout from tqdm.contrib.concurrent import thread_map +from modules.paths import resolve_user_data_dir + base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co" @@ -182,11 +184,13 @@ class ModelDownloader: is_llamacpp = has_gguf and specific_file is not None return links, sha256, is_lora, is_llamacpp, file_sizes - def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None): + def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None, user_data_dir=None): if model_dir: base_folder = model_dir else: - base_folder = 'user_data/models' if not is_lora else 'user_data/loras' + if user_data_dir is None: + user_data_dir = resolve_user_data_dir() + base_folder = str(user_data_dir / 'models') if not is_lora else str(user_data_dir / 'loras') # If the model is of type GGUF, save directly in the base_folder if is_llamacpp: @@ -392,7 +396,8 @@ if __name__ == '__main__': parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.') parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.') - parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/user_data/models).') + parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (user_data/models).') + parser.add_argument('--user-data-dir', type=str, default=None, help='Path to the user data directory. Overrides auto-detection.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.') @@ -421,10 +426,11 @@ if __name__ == '__main__': ) # Get the output folder + user_data_dir = Path(args.user_data_dir) if args.user_data_dir else None if args.output: output_folder = Path(args.output) else: - output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir) + output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir, user_data_dir=user_data_dir) if args.check: # Check previously downloaded files diff --git a/modules/chat.py b/modules/chat.py index f9a98bb6..7c58542f 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -1126,9 +1126,9 @@ def start_new_chat(state): def get_history_file_path(unique_id, character, mode): if mode == 'instruct': - p = Path(f'user_data/logs/instruct/{unique_id}.json') + p = shared.user_data_dir / 'logs' / 'instruct' / f'{unique_id}.json' else: - p = Path(f'user_data/logs/chat/{character}/{unique_id}.json') + p = shared.user_data_dir / 'logs' / 'chat' / character / f'{unique_id}.json' return p @@ -1164,13 +1164,13 @@ def rename_history(old_id, new_id, character, mode): def get_paths(state): if state['mode'] == 'instruct': - return Path('user_data/logs/instruct').glob('*.json') + return (shared.user_data_dir / 'logs' / 'instruct').glob('*.json') else: character = state['character_menu'] # Handle obsolete filenames and paths - old_p = Path(f'user_data/logs/{character}_persistent.json') - new_p = Path(f'user_data/logs/persistent_{character}.json') + old_p = shared.user_data_dir / 'logs' / f'{character}_persistent.json' + new_p = shared.user_data_dir / 'logs' / f'persistent_{character}.json' if old_p.exists(): logger.warning(f"Renaming \"{old_p}\" to \"{new_p}\"") old_p.rename(new_p) @@ -1182,7 +1182,7 @@ def get_paths(state): p.parent.mkdir(exist_ok=True) new_p.rename(p) - return Path(f'user_data/logs/chat/{character}').glob('*.json') + return (shared.user_data_dir / 'logs' / 'chat' / character).glob('*.json') def find_all_histories(state): @@ -1307,7 +1307,7 @@ def get_chat_state_key(character, mode): def load_last_chat_state(): """Load the last chat state from file""" - state_file = Path('user_data/logs/chat_state.json') + state_file = shared.user_data_dir / 'logs' / 'chat_state.json' if state_file.exists(): try: with open(state_file, 'r', encoding='utf-8') as f: @@ -1327,7 +1327,7 @@ def save_last_chat_state(character, mode, unique_id): key = get_chat_state_key(character, mode) state["last_chats"][key] = unique_id - state_file = Path('user_data/logs/chat_state.json') + state_file = shared.user_data_dir / 'logs' / 'chat_state.json' state_file.parent.mkdir(exist_ok=True) with open(state_file, 'w', encoding='utf-8') as f: f.write(json.dumps(state, indent=2)) @@ -1403,7 +1403,7 @@ def generate_pfp_cache(character): if not cache_folder.exists(): cache_folder.mkdir() - for path in [Path(f"user_data/characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: + for path in [shared.user_data_dir / 'characters' / f"{character}.{extension}" for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): original_img = Image.open(path) # Define file paths @@ -1428,12 +1428,12 @@ def load_character(character, name1, name2): filepath = None for extension in ["yml", "yaml", "json"]: - filepath = Path(f'user_data/characters/{character}.{extension}') + filepath = shared.user_data_dir / 'characters' / f'{character}.{extension}' if filepath.exists(): break if filepath is None or not filepath.exists(): - logger.error(f"Could not find the character \"{character}\" inside user_data/characters. No character has been loaded.") + logger.error(f"Could not find the character \"{character}\" inside {shared.user_data_dir}/characters. No character has been loaded.") raise ValueError file_contents = open(filepath, 'r', encoding='utf-8').read() @@ -1509,7 +1509,7 @@ def load_instruction_template(template): if template == 'None': return '' - for filepath in [Path(f'user_data/instruction-templates/{template}.yaml'), Path('user_data/instruction-templates/Alpaca.yaml')]: + for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']: if filepath.exists(): break else: @@ -1552,17 +1552,17 @@ def upload_character(file, img_path, tavern=False): outfile_name = name i = 1 - while Path(f'user_data/characters/{outfile_name}.yaml').exists(): + while (shared.user_data_dir / 'characters' / f'{outfile_name}.yaml').exists(): outfile_name = f'{name}_{i:03d}' i += 1 - with open(Path(f'user_data/characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f: + with open(shared.user_data_dir / 'characters' / f'{outfile_name}.yaml', 'w', encoding='utf-8') as f: f.write(yaml_data) if img is not None: - img.save(Path(f'user_data/characters/{outfile_name}.png')) + img.save(shared.user_data_dir / 'characters' / f'{outfile_name}.png') - logger.info(f'New character saved to "user_data/characters/{outfile_name}.yaml".') + logger.info(f'New character saved to "{shared.user_data_dir}/characters/{outfile_name}.yaml".') return gr.update(value=outfile_name, choices=get_available_characters()) @@ -1643,9 +1643,9 @@ def save_character(name, greeting, context, picture, filename): return data = generate_character_yaml(name, greeting, context) - filepath = Path(f'user_data/characters/{filename}.yaml') + filepath = shared.user_data_dir / 'characters' / f'{filename}.yaml' save_file(filepath, data) - path_to_img = Path(f'user_data/characters/{filename}.png') + path_to_img = shared.user_data_dir / 'characters' / f'{filename}.png' if picture is not None: # Copy the image file from its source path to the character folder shutil.copy(picture, path_to_img) @@ -1655,11 +1655,11 @@ def save_character(name, greeting, context, picture, filename): def delete_character(name, instruct=False): # Check for character data files for extension in ["yml", "yaml", "json"]: - delete_file(Path(f'user_data/characters/{name}.{extension}')) + delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}') # Check for character image files for extension in ["png", "jpg", "jpeg"]: - delete_file(Path(f'user_data/characters/{name}.{extension}')) + delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}') def generate_user_pfp_cache(user): @@ -1668,7 +1668,7 @@ def generate_user_pfp_cache(user): if not cache_folder.exists(): cache_folder.mkdir() - for path in [Path(f"user_data/users/{user}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: + for path in [shared.user_data_dir / 'users' / f"{user}.{extension}" for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): original_img = Image.open(path) # Define file paths @@ -1690,12 +1690,12 @@ def load_user(user_name, name1, user_bio): filepath = None for extension in ["yml", "yaml", "json"]: - filepath = Path(f'user_data/users/{user_name}.{extension}') + filepath = shared.user_data_dir / 'users' / f'{user_name}.{extension}' if filepath.exists(): break if filepath is None or not filepath.exists(): - logger.error(f"Could not find the user \"{user_name}\" inside user_data/users. No user has been loaded.") + logger.error(f"Could not find the user \"{user_name}\" inside {shared.user_data_dir}/users. No user has been loaded.") raise ValueError with open(filepath, 'r', encoding='utf-8') as f: @@ -1741,14 +1741,14 @@ def save_user(name, user_bio, picture, filename): return # Ensure the users directory exists - users_dir = Path('user_data/users') + users_dir = shared.user_data_dir / 'users' users_dir.mkdir(parents=True, exist_ok=True) data = generate_user_yaml(name, user_bio) - filepath = Path(f'user_data/users/{filename}.yaml') + filepath = shared.user_data_dir / 'users' / f'{filename}.yaml' save_file(filepath, data) - path_to_img = Path(f'user_data/users/{filename}.png') + path_to_img = shared.user_data_dir / 'users' / f'{filename}.png' if picture is not None: # Copy the image file from its source path to the users folder shutil.copy(picture, path_to_img) @@ -1759,11 +1759,11 @@ def delete_user(name): """Delete user profile files""" # Check for user data files for extension in ["yml", "yaml", "json"]: - delete_file(Path(f'user_data/users/{name}.{extension}')) + delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}') # Check for user image files for extension in ["png", "jpg", "jpeg"]: - delete_file(Path(f'user_data/users/{name}.{extension}')) + delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}') def update_user_menu_after_deletion(idx): @@ -2224,7 +2224,7 @@ def handle_save_template_click(instruction_template_str): contents = generate_instruction_template_yaml(instruction_template_str) return [ "My Template.yaml", - "user_data/instruction-templates/", + str(shared.user_data_dir / 'instruction-templates') + '/', contents, gr.update(visible=True) ] @@ -2233,7 +2233,7 @@ def handle_save_template_click(instruction_template_str): def handle_delete_template_click(template): return [ f"{template}.yaml", - "user_data/instruction-templates/", + str(shared.user_data_dir / 'instruction-templates') + '/', gr.update(visible=False) ] diff --git a/modules/evaluate.py b/modules/evaluate.py index e3562d5f..78d375cd 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -12,8 +12,8 @@ from modules.text_generation import encode def load_past_evaluations(): - if Path('user_data/logs/evaluations.csv').exists(): - df = pd.read_csv(Path('user_data/logs/evaluations.csv'), dtype=str) + if (shared.user_data_dir / 'logs' / 'evaluations.csv').exists(): + df = pd.read_csv(shared.user_data_dir / 'logs' / 'evaluations.csv', dtype=str) df['Perplexity'] = pd.to_numeric(df['Perplexity']) return df else: @@ -26,7 +26,7 @@ past_evaluations = load_past_evaluations() def save_past_evaluations(df): global past_evaluations past_evaluations = df - filepath = Path('user_data/logs/evaluations.csv') + filepath = shared.user_data_dir / 'logs' / 'evaluations.csv' filepath.parent.mkdir(parents=True, exist_ok=True) df.to_csv(filepath, index=False) @@ -65,7 +65,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): data = load_dataset('ptb_text_only', 'penn_treebank', split='test') text = " ".join(data['sentence']) else: - with open(Path(f'user_data/training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f: + with open(shared.user_data_dir / 'training' / 'datasets' / f'{input_dataset}.txt', 'r', encoding='utf-8') as f: text = f.read() for model in models: diff --git a/modules/extensions.py b/modules/extensions.py index e0010312..dd327882 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,4 +1,6 @@ import importlib +import importlib.util +import sys import traceback from functools import partial from inspect import signature @@ -38,9 +40,15 @@ def load_extensions(): try: # Prefer user extension, fall back to system extension - user_script_path = Path(f'user_data/extensions/{name}/script.py') + user_script_path = shared.user_data_dir / 'extensions' / name / 'script.py' if user_script_path.exists(): - extension = importlib.import_module(f"user_data.extensions.{name}.script") + spec = importlib.util.spec_from_file_location( + f"user_ext_{name}", + str(user_script_path) + ) + extension = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = extension + spec.loader.exec_module(extension) else: extension = importlib.import_module(f"extensions.{name}.script") @@ -53,7 +61,7 @@ def load_extensions(): state[name] = [True, i, extension] # Store extension object except ModuleNotFoundError: - extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name + extension_location = shared.user_data_dir / 'extensions' / name if user_script_path.exists() else Path('extensions') / name windows_path = str(extension_location).replace('/', '\\') logger.error( f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n" diff --git a/modules/html_generator.py b/modules/html_generator.py index 667a64d6..34a5bc57 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -627,10 +627,10 @@ def generate_instruct_html(history, last_message_only=False): def get_character_image_with_cache_buster(): """Get character image URL with cache busting based on file modification time""" - cache_path = Path("user_data/cache/pfp_character_thumb.png") + cache_path = shared.user_data_dir / "cache" / "pfp_character_thumb.png" if cache_path.exists(): mtime = int(cache_path.stat().st_mtime) - return f'' + return f'' return '' @@ -654,8 +654,8 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache= # Get appropriate image if role == "user": - img = (f'' - if Path("user_data/cache/pfp_me.png").exists() else '') + img = (f'' + if (shared.user_data_dir / "cache" / "pfp_me.png").exists() else '') else: img = img_bot diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 017c5d2a..eb99f2a8 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -384,7 +384,7 @@ class LlamaServer: if shared.args.mmproj not in [None, 'None']: path = Path(shared.args.mmproj) if not path.exists(): - path = Path('user_data/mmproj') / shared.args.mmproj + path = shared.user_data_dir / 'mmproj' / shared.args.mmproj if path.exists(): cmd += ["--mmproj", str(path)] diff --git a/modules/presets.py b/modules/presets.py index 16dfa6ad..9aab7e3c 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -66,7 +66,7 @@ def presets_params(): def load_preset(name, verbose=False): generate_params = default_preset() if name not in ['None', None, '']: - path = Path(f'user_data/presets/{name}.yaml') + path = shared.user_data_dir / 'presets' / f'{name}.yaml' if path.exists(): with open(path, 'r') as infile: preset = yaml.safe_load(infile) diff --git a/modules/prompts.py b/modules/prompts.py index 91f5812a..d107ce5a 100644 --- a/modules/prompts.py +++ b/modules/prompts.py @@ -8,7 +8,7 @@ def load_prompt(fname): if not fname: # Create new file new_name = utils.current_time() - prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt" + prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) initial_content = "In this story," prompt_path.write_text(initial_content, encoding='utf-8') @@ -18,7 +18,7 @@ def load_prompt(fname): return initial_content - file_path = Path(f'user_data/logs/notebook/{fname}.txt') + file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt' if file_path.exists(): with open(file_path, 'r', encoding='utf-8') as f: text = f.read() diff --git a/modules/shared.py b/modules/shared.py index 4e212d1b..57b7e3e9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,8 +9,12 @@ from pathlib import Path import yaml from modules.logging_colors import logger +from modules.paths import resolve_user_data_dir from modules.presets import default_preset +# Resolve user_data directory early (before argparse defaults are set) +user_data_dir = resolve_user_data_dir() + # Text model variables model = None tokenizer = None @@ -42,11 +46,12 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_ # Basic settings group = parser.add_argument_group('Basic settings') +group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.') group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.') group.add_argument('--model', type=str, help='Name of the model to load by default.') group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') -group.add_argument('--model-dir', type=str, default='user_data/models', help='Path to directory with all the models.') -group.add_argument('--lora-dir', type=str, default='user_data/loras', help='Path to directory with all the loras.') +group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.') +group.add_argument('--lora-dir', type=str, default=str(user_data_dir / 'loras'), help='Path to directory with all the loras.') group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.') group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.') group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') @@ -56,7 +61,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft # Image generation group = parser.add_argument_group('Image model') group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).') -group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.') +group.add_argument('--image-model-dir', type=str, default=str(user_data_dir / 'image_models'), help='Path to directory with all the image models.') group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.') group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.') group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.') @@ -110,7 +115,7 @@ group = parser.add_argument_group('Transformers/Accelerate') group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.') group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') -group.add_argument('--disk-cache-dir', type=str, default='user_data/cache', help='Directory to save the disk cache to. Defaults to "user_data/cache".') +group.add_argument('--disk-cache-dir', type=str, default=str(user_data_dir / 'cache'), help='Directory to save the disk cache to.') group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).') group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.') @@ -167,7 +172,7 @@ group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.') # Handle CMD_FLAGS.txt -cmd_flags_path = Path(__file__).parent.parent / "user_data" / "CMD_FLAGS.txt" +cmd_flags_path = user_data_dir / "CMD_FLAGS.txt" if cmd_flags_path.exists(): with cmd_flags_path.open('r', encoding='utf-8') as f: cmd_flags = ' '.join( @@ -182,6 +187,7 @@ if cmd_flags_path.exists(): args = parser.parse_args() +user_data_dir = Path(args.user_data_dir) # Update from parsed args (may differ from pre-parse) original_args = copy.deepcopy(args) args_defaults = parser.parse_args([]) @@ -212,7 +218,7 @@ settings = { 'enable_web_search': False, 'web_search_pages': 3, 'prompt-notebook': '', - 'preset': 'Qwen3 - Thinking' if Path('user_data/presets/Qwen3 - Thinking.yaml').exists() else None, + 'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None, 'max_new_tokens': 512, 'max_new_tokens_min': 1, 'max_new_tokens_max': 4096, diff --git a/modules/training.py b/modules/training.py index ca8d1bd1..c9f32e64 100644 --- a/modules/training.py +++ b/modules/training.py @@ -107,8 +107,8 @@ def create_ui(): with gr.Column(): with gr.Tab(label='Chat Dataset'): with gr.Row(): - dataset = gr.Dropdown(choices=utils.get_chat_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu) - ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu) + dataset = gr.Dropdown(choices=utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu) + ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu) with gr.Row(): format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu) @@ -116,14 +116,14 @@ def create_ui(): with gr.Tab(label="Text Dataset"): with gr.Row(): - text_dataset = gr.Dropdown(choices=utils.get_text_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu) - ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu) + text_dataset = gr.Dropdown(choices=utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu) + ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu) stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.') with gr.Row(): - eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu) - ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu) + eval_dataset = gr.Dropdown(choices=utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu) + ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json')}, 'refresh-button', interactive=not mu) eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') @@ -137,7 +137,7 @@ def create_ui(): with gr.Row(): with gr.Column(): models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu) - evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('user_data/training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under user_data/training/datasets.', interactive=not mu) + evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'txt')[1:], value='wikitext', label='Input dataset', info=f'The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under {shared.user_data_dir}/training/datasets.', interactive=not mu) with gr.Row(): with gr.Column(): stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') @@ -224,7 +224,7 @@ def clean_path(base_path: str, path: str): def get_instruction_templates(): - path = Path('user_data/instruction-templates') + path = shared.user_data_dir / 'instruction-templates' names = set() for ext in ['yaml', 'yml', 'jinja']: for f in path.glob(f'*.{ext}'): @@ -233,8 +233,8 @@ def get_instruction_templates(): def load_template(name): - """Load a Jinja2 template string from user_data/instruction-templates/.""" - path = Path('user_data/instruction-templates') + """Load a Jinja2 template string from {user_data_dir}/instruction-templates/.""" + path = shared.user_data_dir / 'instruction-templates' for ext in ['jinja', 'yaml', 'yml']: filepath = path / f'{name}.{ext}' if filepath.exists(): @@ -453,7 +453,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: if has_text_dataset: train_template["template_type"] = "text_dataset" logger.info("Loading text dataset") - data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{text_dataset}.json')) + data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{text_dataset}.json')) if "text" not in data['train'].column_names: yield "Error: text dataset must have a \"text\" key per row." @@ -467,7 +467,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: if eval_dataset == 'None': eval_data = None else: - eval_raw = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json')) + eval_raw = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) if "text" not in eval_raw['train'].column_names: yield "Error: evaluation dataset must have a \"text\" key per row." return @@ -496,7 +496,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: train_template["template_type"] = "chat_template" logger.info("Loading JSON dataset with chat template format") - data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json')) + data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{dataset}.json')) # Validate the first row try: @@ -522,7 +522,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: if eval_dataset == 'None': eval_data = None else: - eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json')) + eval_data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) eval_data = eval_data['train'].map( tokenize_conversation, remove_columns=eval_data['train'].column_names, @@ -757,11 +757,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: decoded_entries.append({"value": decoded_text}) # Write the log file - Path('user_data/logs').mkdir(exist_ok=True) - with open(Path('user_data/logs/train_dataset_sample.json'), 'w') as json_file: + (shared.user_data_dir / 'logs').mkdir(exist_ok=True) + with open(shared.user_data_dir / 'logs' / 'train_dataset_sample.json', 'w') as json_file: json.dump(decoded_entries, json_file, indent=4) - logger.info("Log file 'train_dataset_sample.json' created in the 'user_data/logs' directory.") + logger.info(f"Log file 'train_dataset_sample.json' created in the '{shared.user_data_dir}/logs' directory.") except Exception as e: logger.error(f"Failed to create log file due to error: {e}") diff --git a/modules/ui.py b/modules/ui.py index 0cd2cf43..ae998ebb 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -113,7 +113,7 @@ if not shared.args.old_colors: block_radius='0', ) -if Path("user_data/notification.mp3").exists(): +if (shared.user_data_dir / "notification.mp3").exists(): audio_notification_js = "document.querySelector('#audio_notification audio')?.play();" else: audio_notification_js = "" @@ -381,7 +381,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma output[_id] = params[param] else: # Preserve existing extensions and extension parameters during autosave - settings_path = Path('user_data') / 'settings.yaml' + settings_path = shared.user_data_dir / 'settings.yaml' if settings_path.exists(): try: with open(settings_path, 'r', encoding='utf-8') as f: @@ -436,7 +436,7 @@ def _perform_debounced_save(): try: if _last_interface_state is not None: contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False) - settings_path = Path('user_data') / 'settings.yaml' + settings_path = shared.user_data_dir / 'settings.yaml' settings_path.parent.mkdir(exist_ok=True) with open(settings_path, 'w', encoding='utf-8') as f: f.write(contents) diff --git a/modules/ui_chat.py b/modules/ui_chat.py index aec9051e..74da0a40 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -175,7 +175,7 @@ def create_character_settings_ui(): with gr.Column(scale=1): shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu) - shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu) + shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(shared.user_data_dir / 'cache' / 'pfp_me.png') if (shared.user_data_dir / 'cache' / 'pfp_me.png').exists() else None, interactive=not mu) def create_chat_settings_ui(): diff --git a/modules/ui_default.py b/modules/ui_default.py index c0feae19..2c367cca 100644 --- a/modules/ui_default.py +++ b/modules/ui_default.py @@ -159,7 +159,7 @@ def handle_new_prompt(): new_name = utils.current_time() # Create the new prompt file - prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt" + prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text("In this story,", encoding='utf-8') @@ -170,15 +170,15 @@ def handle_delete_prompt_confirm_default(prompt_name): available_prompts = utils.get_available_prompts() current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 - (Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True) + (shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) available_prompts = utils.get_available_prompts() if available_prompts: new_value = available_prompts[min(current_index, len(available_prompts) - 1)] else: new_value = utils.current_time() - Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True) - (Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,") + (shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) + (shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") available_prompts = [new_value] return [ @@ -199,8 +199,8 @@ def handle_rename_prompt_click_default(current_name): def handle_rename_prompt_confirm_default(new_name, current_name): - old_path = Path("user_data/logs/notebook") / f"{current_name}.txt" - new_path = Path("user_data/logs/notebook") / f"{new_name}.txt" + old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" + new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" if old_path.exists() and not new_path.exists(): old_path.rename(new_path) diff --git a/modules/ui_file_saving.py b/modules/ui_file_saving.py index 720bfdec..46087ace 100644 --- a/modules/ui_file_saving.py +++ b/modules/ui_file_saving.py @@ -28,7 +28,7 @@ def create_ui(): # Character saver/deleter with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']: - shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your user_data/characters folder with this base filename.') + shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info=f'The character will be saved to your {shared.user_data_dir}/characters folder with this base filename.') with gr.Row(): shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) @@ -41,7 +41,7 @@ def create_ui(): # User saver/deleter with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']: - shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info='The user profile will be saved to your user_data/users folder with this base filename.') + shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info=f'The user profile will be saved to your {shared.user_data_dir}/users folder with this base filename.') with gr.Row(): shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) @@ -54,7 +54,7 @@ def create_ui(): # Preset saver with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']: - shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info='The preset will be saved to your user_data/presets folder with this base filename.') + shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info=f'The preset will be saved to your {shared.user_data_dir}/presets folder with this base filename.') shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents') with gr.Row(): shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button") @@ -91,7 +91,7 @@ def create_event_handlers(): def handle_save_preset_confirm_click(filename, contents): try: - utils.save_file(f"user_data/presets/{filename}.yaml", contents) + utils.save_file(str(shared.user_data_dir / "presets" / f"{filename}.yaml"), contents) available_presets = utils.get_available_presets() output = gr.update(choices=available_presets, value=filename) except Exception: @@ -164,7 +164,7 @@ def handle_save_preset_click(state): def handle_delete_preset_click(preset): return [ f"{preset}.yaml", - "user_data/presets/", + str(shared.user_data_dir / "presets") + "/", gr.update(visible=True) ] @@ -173,7 +173,7 @@ def handle_save_grammar_click(grammar_string): return [ grammar_string, "My Fancy Grammar.gbnf", - "user_data/grammars/", + str(shared.user_data_dir / "grammars") + "/", gr.update(visible=True) ] @@ -181,7 +181,7 @@ def handle_save_grammar_click(grammar_string): def handle_delete_grammar_click(grammar_file): return [ grammar_file, - "user_data/grammars/", + str(shared.user_data_dir / "grammars") + "/", gr.update(visible=True) ] diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index 6b39c5b5..e9df9bd3 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -138,7 +138,7 @@ def save_generated_images(images, state, actual_seed): return [] date_str = datetime.now().strftime("%Y-%m-%d") - folder_path = os.path.join("user_data", "image_outputs", date_str) + folder_path = str(shared.user_data_dir / "image_outputs" / date_str) os.makedirs(folder_path, exist_ok=True) metadata = build_generation_metadata(state, actual_seed) @@ -214,7 +214,7 @@ def get_all_history_images(force_refresh=False): """Get all history images sorted by modification time (newest first). Uses caching.""" global _image_cache, _cache_timestamp - output_dir = os.path.join("user_data", "image_outputs") + output_dir = str(shared.user_data_dir / "image_outputs") if not os.path.exists(output_dir): return [] diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index e6e57a22..33b39a25 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -65,7 +65,7 @@ def create_ui(): # Multimodal with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']: with gr.Row(): - shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu) + shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu) ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu) # Speculative decoding @@ -317,9 +317,9 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None ) - if output_folder == Path("user_data/models"): + if output_folder == shared.user_data_dir / "models": output_folder = Path(shared.args.model_dir) - elif output_folder == Path("user_data/loras"): + elif output_folder == shared.user_data_dir / "loras": output_folder = Path(shared.args.lora_dir) if check: diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py index 9fab879b..f550e646 100644 --- a/modules/ui_notebook.py +++ b/modules/ui_notebook.py @@ -194,7 +194,7 @@ def handle_new_prompt(): new_name = utils.current_time() # Create the new prompt file - prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt" + prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text("In this story,", encoding='utf-8') @@ -205,15 +205,15 @@ def handle_delete_prompt_confirm_notebook(prompt_name): available_prompts = utils.get_available_prompts() current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 - (Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True) + (shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) available_prompts = utils.get_available_prompts() if available_prompts: new_value = available_prompts[min(current_index, len(available_prompts) - 1)] else: new_value = utils.current_time() - Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True) - (Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,") + (shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) + (shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") available_prompts = [new_value] return [ @@ -233,8 +233,8 @@ def handle_rename_prompt_click_notebook(current_name): def handle_rename_prompt_confirm_notebook(new_name, current_name): - old_path = Path("user_data/logs/notebook") / f"{current_name}.txt" - new_path = Path("user_data/logs/notebook") / f"{new_name}.txt" + old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" + new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" if old_path.exists() and not new_path.exists(): old_path.rename(new_path) @@ -250,7 +250,7 @@ def handle_rename_prompt_confirm_notebook(new_name, current_name): def autosave_prompt(text, prompt_name): """Automatically save the text to the selected prompt file""" if prompt_name and text.strip(): - prompt_path = Path("user_data/logs/notebook") / f"{prompt_name}.txt" + prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text(text, encoding='utf-8') diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 23882084..e5eb9210 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -135,7 +135,7 @@ def get_truncation_length(): def load_grammar(name): - p = Path(f'user_data/grammars/{name}') + p = shared.user_data_dir / 'grammars' / name if p.exists(): return open(p, 'r', encoding='utf-8').read() else: diff --git a/modules/ui_session.py b/modules/ui_session.py index 60b19f47..e1807dea 100644 --- a/modules/ui_session.py +++ b/modules/ui_session.py @@ -17,7 +17,7 @@ def create_ui(): with gr.Column(): gr.Markdown("## Extensions & flags") - shared.gradio['save_settings'] = gr.Button('Save extensions settings to user_data/settings.yaml', elem_classes='refresh-button', interactive=not mu) + shared.gradio['save_settings'] = gr.Button(f'Save extensions settings to {shared.user_data_dir}/settings.yaml', elem_classes='refresh-button', interactive=not mu) shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu) with gr.Row(): with gr.Column(): @@ -54,7 +54,7 @@ def handle_save_settings(state, preset, extensions, show_controls, theme): return [ contents, "settings.yaml", - "user_data/", + str(shared.user_data_dir) + "/", gr.update(visible=True) ] diff --git a/modules/utils.py b/modules/utils.py index 456b5f83..447685b7 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -15,16 +15,31 @@ def gradio(*keys): return [shared.gradio[k] for k in keys] +def _is_path_allowed(abs_path_str): + """Check if a path is under the project root or the configured user_data directory.""" + abs_path = Path(abs_path_str).resolve() + root_folder = Path(__file__).resolve().parent.parent + user_data_resolved = shared.user_data_dir.resolve() + try: + abs_path.relative_to(root_folder) + return True + except ValueError: + pass + try: + abs_path.relative_to(user_data_resolved) + return True + except ValueError: + pass + return False + + def save_file(fname, contents): if fname == '': logger.error('File name is empty!') return - root_folder = Path(__file__).resolve().parent.parent abs_path_str = os.path.abspath(fname) - rel_path_str = os.path.relpath(abs_path_str, root_folder) - rel_path = Path(rel_path_str) - if rel_path.parts[0] == '..': + if not _is_path_allowed(abs_path_str): logger.error(f'Invalid file path: \"{fname}\"') return @@ -39,16 +54,14 @@ def delete_file(fname): logger.error('File name is empty!') return - root_folder = Path(__file__).resolve().parent.parent abs_path_str = os.path.abspath(fname) - rel_path_str = os.path.relpath(abs_path_str, root_folder) - rel_path = Path(rel_path_str) - if rel_path.parts[0] == '..': + if not _is_path_allowed(abs_path_str): logger.error(f'Invalid file path: \"{fname}\"') return - if rel_path.exists(): - rel_path.unlink() + p = Path(abs_path_str) + if p.exists(): + p.unlink() logger.info(f'Deleted \"{fname}\".') @@ -75,7 +88,7 @@ def natural_keys(text): def check_model_loaded(): if shared.model_name == 'None' or shared.model is None: if len(get_available_models()) == 0: - error_msg = "No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your user_data/models folder\n2) Go to the Model tab and select it" + error_msg = f"No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your {shared.user_data_dir}/models folder\n2) Go to the Model tab and select it" logger.error(error_msg) return False, error_msg else: @@ -188,7 +201,7 @@ def get_available_ggufs(): def get_available_mmproj(): - mmproj_dir = Path('user_data/mmproj') + mmproj_dir = shared.user_data_dir / 'mmproj' if not mmproj_dir.exists(): return ['None'] @@ -201,11 +214,11 @@ def get_available_mmproj(): def get_available_presets(): - return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys) + return sorted(set((k.stem for k in (shared.user_data_dir / 'presets').glob('*.yaml'))), key=natural_keys) def get_available_prompts(): - notebook_dir = Path('user_data/logs/notebook') + notebook_dir = shared.user_data_dir / 'logs' / 'notebook' notebook_dir.mkdir(parents=True, exist_ok=True) prompt_files = list(notebook_dir.glob('*.txt')) @@ -221,19 +234,19 @@ def get_available_prompts(): def get_available_characters(): - paths = (x for x in Path('user_data/characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + paths = (x for x in (shared.user_data_dir / 'characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_users(): - users_dir = Path('user_data/users') + users_dir = shared.user_data_dir / 'users' users_dir.mkdir(parents=True, exist_ok=True) paths = (x for x in users_dir.iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_instruction_templates(): - path = "user_data/instruction-templates" + path = str(shared.user_data_dir / "instruction-templates") paths = [] if os.path.exists(path): paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) @@ -244,13 +257,13 @@ def get_available_instruction_templates(): def get_available_extensions(): # User extensions (higher priority) user_extensions = [] - user_ext_path = Path('user_data/extensions') + user_ext_path = shared.user_data_dir / 'extensions' if user_ext_path.exists(): - user_exts = map(lambda x: x.parts[2], user_ext_path.glob('*/script.py')) + user_exts = map(lambda x: x.parent.name, user_ext_path.glob('*/script.py')) user_extensions = sorted(set(user_exts), key=natural_keys) # System extensions (excluding those overridden by user extensions) - system_exts = map(lambda x: x.parts[1], Path('extensions').glob('*/script.py')) + system_exts = map(lambda x: x.parent.name, Path('extensions').glob('*/script.py')) system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys) return user_extensions + system_extensions @@ -306,4 +319,4 @@ def get_available_chat_styles(): def get_available_grammars(): - return ['None'] + sorted([item.name for item in list(Path('user_data/grammars').glob('*.gbnf'))], key=natural_keys) + return ['None'] + sorted([item.name for item in list((shared.user_data_dir / 'grammars').glob('*.gbnf'))], key=natural_keys) diff --git a/server.py b/server.py index 000ea9fb..a809cd7b 100644 --- a/server.py +++ b/server.py @@ -9,7 +9,7 @@ from modules.logging_colors import logger from modules.prompts import load_prompt # Set up Gradio temp directory path -gradio_temp_path = Path('user_data') / 'cache' / 'gradio' +gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio' shutil.rmtree(gradio_temp_path, ignore_errors=True) gradio_temp_path.mkdir(parents=True, exist_ok=True) @@ -94,9 +94,9 @@ def create_interface(): auth = [tuple(cred.split(':')) for cred in auth] # Allowed paths - allowed_paths = ["css", "js", "extensions", "user_data/cache"] + allowed_paths = ["css", "js", "extensions", str(shared.user_data_dir / "cache")] if not shared.args.multi_user: - allowed_paths.append("user_data/image_outputs") + allowed_paths.append(str(shared.user_data_dir / "image_outputs")) # Import the extensions and execute their setup() functions if shared.args.extensions is not None and len(shared.args.extensions) > 0: @@ -120,7 +120,7 @@ def create_interface(): # Clear existing cache files for cache_file in ['pfp_character.png', 'pfp_character_thumb.png']: - cache_path = Path(f"user_data/cache/{cache_file}") + cache_path = shared.user_data_dir / "cache" / cache_file if cache_path.exists(): cache_path.unlink() @@ -160,8 +160,8 @@ def create_interface(): shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) # Audio notification - if Path("user_data/notification.mp3").exists(): - shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="user_data/notification.mp3", elem_id="audio_notification", visible=False) + if (shared.user_data_dir / "notification.mp3").exists(): + shared.gradio['audio_notification'] = gr.Audio(interactive=False, value=str(shared.user_data_dir / "notification.mp3"), elem_id="audio_notification", visible=False) # Floating menus for saving/deleting files ui_file_saving.create_ui() @@ -244,8 +244,8 @@ if __name__ == "__main__": settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): settings_file = Path(shared.args.settings) - elif Path('user_data/settings.yaml').exists(): - settings_file = Path('user_data/settings.yaml') + elif (shared.user_data_dir / 'settings.yaml').exists(): + settings_file = shared.user_data_dir / 'settings.yaml' if settings_file is not None: logger.info(f"Loading settings from \"{settings_file}\"")