mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-09 15:13:56 +01:00
Make user_data configurable: add --user-data-dir flag, auto-detect ../user_data
If --user-data-dir is not set, auto-detect: use ../user_data when ./user_data doesn't exist, making it easy to share user data across portable builds by placing it one folder up.
This commit is contained in:
parent
4c406e024f
commit
e2548f69a9
|
|
@ -24,6 +24,8 @@ from requests.adapters import HTTPAdapter
|
||||||
from requests.exceptions import ConnectionError, RequestException, Timeout
|
from requests.exceptions import ConnectionError, RequestException, Timeout
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
from modules.paths import resolve_user_data_dir
|
||||||
|
|
||||||
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -182,11 +184,13 @@ class ModelDownloader:
|
||||||
is_llamacpp = has_gguf and specific_file is not None
|
is_llamacpp = has_gguf and specific_file is not None
|
||||||
return links, sha256, is_lora, is_llamacpp, file_sizes
|
return links, sha256, is_lora, is_llamacpp, file_sizes
|
||||||
|
|
||||||
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
|
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None, user_data_dir=None):
|
||||||
if model_dir:
|
if model_dir:
|
||||||
base_folder = model_dir
|
base_folder = model_dir
|
||||||
else:
|
else:
|
||||||
base_folder = 'user_data/models' if not is_lora else 'user_data/loras'
|
if user_data_dir is None:
|
||||||
|
user_data_dir = resolve_user_data_dir()
|
||||||
|
base_folder = str(user_data_dir / 'models') if not is_lora else str(user_data_dir / 'loras')
|
||||||
|
|
||||||
# If the model is of type GGUF, save directly in the base_folder
|
# If the model is of type GGUF, save directly in the base_folder
|
||||||
if is_llamacpp:
|
if is_llamacpp:
|
||||||
|
|
@ -392,7 +396,8 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
|
||||||
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
|
parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.')
|
||||||
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
|
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
|
||||||
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/user_data/models).')
|
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (user_data/models).')
|
||||||
|
parser.add_argument('--user-data-dir', type=str, default=None, help='Path to the user data directory. Overrides auto-detection.')
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
|
parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.')
|
||||||
|
|
@ -421,10 +426,11 @@ if __name__ == '__main__':
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the output folder
|
# Get the output folder
|
||||||
|
user_data_dir = Path(args.user_data_dir) if args.user_data_dir else None
|
||||||
if args.output:
|
if args.output:
|
||||||
output_folder = Path(args.output)
|
output_folder = Path(args.output)
|
||||||
else:
|
else:
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)
|
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir, user_data_dir=user_data_dir)
|
||||||
|
|
||||||
if args.check:
|
if args.check:
|
||||||
# Check previously downloaded files
|
# Check previously downloaded files
|
||||||
|
|
|
||||||
|
|
@ -1126,9 +1126,9 @@ def start_new_chat(state):
|
||||||
|
|
||||||
def get_history_file_path(unique_id, character, mode):
|
def get_history_file_path(unique_id, character, mode):
|
||||||
if mode == 'instruct':
|
if mode == 'instruct':
|
||||||
p = Path(f'user_data/logs/instruct/{unique_id}.json')
|
p = shared.user_data_dir / 'logs' / 'instruct' / f'{unique_id}.json'
|
||||||
else:
|
else:
|
||||||
p = Path(f'user_data/logs/chat/{character}/{unique_id}.json')
|
p = shared.user_data_dir / 'logs' / 'chat' / character / f'{unique_id}.json'
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
@ -1164,13 +1164,13 @@ def rename_history(old_id, new_id, character, mode):
|
||||||
|
|
||||||
def get_paths(state):
|
def get_paths(state):
|
||||||
if state['mode'] == 'instruct':
|
if state['mode'] == 'instruct':
|
||||||
return Path('user_data/logs/instruct').glob('*.json')
|
return (shared.user_data_dir / 'logs' / 'instruct').glob('*.json')
|
||||||
else:
|
else:
|
||||||
character = state['character_menu']
|
character = state['character_menu']
|
||||||
|
|
||||||
# Handle obsolete filenames and paths
|
# Handle obsolete filenames and paths
|
||||||
old_p = Path(f'user_data/logs/{character}_persistent.json')
|
old_p = shared.user_data_dir / 'logs' / f'{character}_persistent.json'
|
||||||
new_p = Path(f'user_data/logs/persistent_{character}.json')
|
new_p = shared.user_data_dir / 'logs' / f'persistent_{character}.json'
|
||||||
if old_p.exists():
|
if old_p.exists():
|
||||||
logger.warning(f"Renaming \"{old_p}\" to \"{new_p}\"")
|
logger.warning(f"Renaming \"{old_p}\" to \"{new_p}\"")
|
||||||
old_p.rename(new_p)
|
old_p.rename(new_p)
|
||||||
|
|
@ -1182,7 +1182,7 @@ def get_paths(state):
|
||||||
p.parent.mkdir(exist_ok=True)
|
p.parent.mkdir(exist_ok=True)
|
||||||
new_p.rename(p)
|
new_p.rename(p)
|
||||||
|
|
||||||
return Path(f'user_data/logs/chat/{character}').glob('*.json')
|
return (shared.user_data_dir / 'logs' / 'chat' / character).glob('*.json')
|
||||||
|
|
||||||
|
|
||||||
def find_all_histories(state):
|
def find_all_histories(state):
|
||||||
|
|
@ -1307,7 +1307,7 @@ def get_chat_state_key(character, mode):
|
||||||
|
|
||||||
def load_last_chat_state():
|
def load_last_chat_state():
|
||||||
"""Load the last chat state from file"""
|
"""Load the last chat state from file"""
|
||||||
state_file = Path('user_data/logs/chat_state.json')
|
state_file = shared.user_data_dir / 'logs' / 'chat_state.json'
|
||||||
if state_file.exists():
|
if state_file.exists():
|
||||||
try:
|
try:
|
||||||
with open(state_file, 'r', encoding='utf-8') as f:
|
with open(state_file, 'r', encoding='utf-8') as f:
|
||||||
|
|
@ -1327,7 +1327,7 @@ def save_last_chat_state(character, mode, unique_id):
|
||||||
key = get_chat_state_key(character, mode)
|
key = get_chat_state_key(character, mode)
|
||||||
state["last_chats"][key] = unique_id
|
state["last_chats"][key] = unique_id
|
||||||
|
|
||||||
state_file = Path('user_data/logs/chat_state.json')
|
state_file = shared.user_data_dir / 'logs' / 'chat_state.json'
|
||||||
state_file.parent.mkdir(exist_ok=True)
|
state_file.parent.mkdir(exist_ok=True)
|
||||||
with open(state_file, 'w', encoding='utf-8') as f:
|
with open(state_file, 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps(state, indent=2))
|
f.write(json.dumps(state, indent=2))
|
||||||
|
|
@ -1403,7 +1403,7 @@ def generate_pfp_cache(character):
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
cache_folder.mkdir()
|
cache_folder.mkdir()
|
||||||
|
|
||||||
for path in [Path(f"user_data/characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
for path in [shared.user_data_dir / 'characters' / f"{character}.{extension}" for extension in ['png', 'jpg', 'jpeg']]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
original_img = Image.open(path)
|
original_img = Image.open(path)
|
||||||
# Define file paths
|
# Define file paths
|
||||||
|
|
@ -1428,12 +1428,12 @@ def load_character(character, name1, name2):
|
||||||
|
|
||||||
filepath = None
|
filepath = None
|
||||||
for extension in ["yml", "yaml", "json"]:
|
for extension in ["yml", "yaml", "json"]:
|
||||||
filepath = Path(f'user_data/characters/{character}.{extension}')
|
filepath = shared.user_data_dir / 'characters' / f'{character}.{extension}'
|
||||||
if filepath.exists():
|
if filepath.exists():
|
||||||
break
|
break
|
||||||
|
|
||||||
if filepath is None or not filepath.exists():
|
if filepath is None or not filepath.exists():
|
||||||
logger.error(f"Could not find the character \"{character}\" inside user_data/characters. No character has been loaded.")
|
logger.error(f"Could not find the character \"{character}\" inside {shared.user_data_dir}/characters. No character has been loaded.")
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
||||||
|
|
@ -1509,7 +1509,7 @@ def load_instruction_template(template):
|
||||||
if template == 'None':
|
if template == 'None':
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
for filepath in [Path(f'user_data/instruction-templates/{template}.yaml'), Path('user_data/instruction-templates/Alpaca.yaml')]:
|
for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']:
|
||||||
if filepath.exists():
|
if filepath.exists():
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
@ -1552,17 +1552,17 @@ def upload_character(file, img_path, tavern=False):
|
||||||
|
|
||||||
outfile_name = name
|
outfile_name = name
|
||||||
i = 1
|
i = 1
|
||||||
while Path(f'user_data/characters/{outfile_name}.yaml').exists():
|
while (shared.user_data_dir / 'characters' / f'{outfile_name}.yaml').exists():
|
||||||
outfile_name = f'{name}_{i:03d}'
|
outfile_name = f'{name}_{i:03d}'
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
with open(Path(f'user_data/characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f:
|
with open(shared.user_data_dir / 'characters' / f'{outfile_name}.yaml', 'w', encoding='utf-8') as f:
|
||||||
f.write(yaml_data)
|
f.write(yaml_data)
|
||||||
|
|
||||||
if img is not None:
|
if img is not None:
|
||||||
img.save(Path(f'user_data/characters/{outfile_name}.png'))
|
img.save(shared.user_data_dir / 'characters' / f'{outfile_name}.png')
|
||||||
|
|
||||||
logger.info(f'New character saved to "user_data/characters/{outfile_name}.yaml".')
|
logger.info(f'New character saved to "{shared.user_data_dir}/characters/{outfile_name}.yaml".')
|
||||||
return gr.update(value=outfile_name, choices=get_available_characters())
|
return gr.update(value=outfile_name, choices=get_available_characters())
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1643,9 +1643,9 @@ def save_character(name, greeting, context, picture, filename):
|
||||||
return
|
return
|
||||||
|
|
||||||
data = generate_character_yaml(name, greeting, context)
|
data = generate_character_yaml(name, greeting, context)
|
||||||
filepath = Path(f'user_data/characters/{filename}.yaml')
|
filepath = shared.user_data_dir / 'characters' / f'{filename}.yaml'
|
||||||
save_file(filepath, data)
|
save_file(filepath, data)
|
||||||
path_to_img = Path(f'user_data/characters/{filename}.png')
|
path_to_img = shared.user_data_dir / 'characters' / f'{filename}.png'
|
||||||
if picture is not None:
|
if picture is not None:
|
||||||
# Copy the image file from its source path to the character folder
|
# Copy the image file from its source path to the character folder
|
||||||
shutil.copy(picture, path_to_img)
|
shutil.copy(picture, path_to_img)
|
||||||
|
|
@ -1655,11 +1655,11 @@ def save_character(name, greeting, context, picture, filename):
|
||||||
def delete_character(name, instruct=False):
|
def delete_character(name, instruct=False):
|
||||||
# Check for character data files
|
# Check for character data files
|
||||||
for extension in ["yml", "yaml", "json"]:
|
for extension in ["yml", "yaml", "json"]:
|
||||||
delete_file(Path(f'user_data/characters/{name}.{extension}'))
|
delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}')
|
||||||
|
|
||||||
# Check for character image files
|
# Check for character image files
|
||||||
for extension in ["png", "jpg", "jpeg"]:
|
for extension in ["png", "jpg", "jpeg"]:
|
||||||
delete_file(Path(f'user_data/characters/{name}.{extension}'))
|
delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}')
|
||||||
|
|
||||||
|
|
||||||
def generate_user_pfp_cache(user):
|
def generate_user_pfp_cache(user):
|
||||||
|
|
@ -1668,7 +1668,7 @@ def generate_user_pfp_cache(user):
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
cache_folder.mkdir()
|
cache_folder.mkdir()
|
||||||
|
|
||||||
for path in [Path(f"user_data/users/{user}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
for path in [shared.user_data_dir / 'users' / f"{user}.{extension}" for extension in ['png', 'jpg', 'jpeg']]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
original_img = Image.open(path)
|
original_img = Image.open(path)
|
||||||
# Define file paths
|
# Define file paths
|
||||||
|
|
@ -1690,12 +1690,12 @@ def load_user(user_name, name1, user_bio):
|
||||||
|
|
||||||
filepath = None
|
filepath = None
|
||||||
for extension in ["yml", "yaml", "json"]:
|
for extension in ["yml", "yaml", "json"]:
|
||||||
filepath = Path(f'user_data/users/{user_name}.{extension}')
|
filepath = shared.user_data_dir / 'users' / f'{user_name}.{extension}'
|
||||||
if filepath.exists():
|
if filepath.exists():
|
||||||
break
|
break
|
||||||
|
|
||||||
if filepath is None or not filepath.exists():
|
if filepath is None or not filepath.exists():
|
||||||
logger.error(f"Could not find the user \"{user_name}\" inside user_data/users. No user has been loaded.")
|
logger.error(f"Could not find the user \"{user_name}\" inside {shared.user_data_dir}/users. No user has been loaded.")
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
with open(filepath, 'r', encoding='utf-8') as f:
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
|
@ -1741,14 +1741,14 @@ def save_user(name, user_bio, picture, filename):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure the users directory exists
|
# Ensure the users directory exists
|
||||||
users_dir = Path('user_data/users')
|
users_dir = shared.user_data_dir / 'users'
|
||||||
users_dir.mkdir(parents=True, exist_ok=True)
|
users_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
data = generate_user_yaml(name, user_bio)
|
data = generate_user_yaml(name, user_bio)
|
||||||
filepath = Path(f'user_data/users/{filename}.yaml')
|
filepath = shared.user_data_dir / 'users' / f'{filename}.yaml'
|
||||||
save_file(filepath, data)
|
save_file(filepath, data)
|
||||||
|
|
||||||
path_to_img = Path(f'user_data/users/{filename}.png')
|
path_to_img = shared.user_data_dir / 'users' / f'{filename}.png'
|
||||||
if picture is not None:
|
if picture is not None:
|
||||||
# Copy the image file from its source path to the users folder
|
# Copy the image file from its source path to the users folder
|
||||||
shutil.copy(picture, path_to_img)
|
shutil.copy(picture, path_to_img)
|
||||||
|
|
@ -1759,11 +1759,11 @@ def delete_user(name):
|
||||||
"""Delete user profile files"""
|
"""Delete user profile files"""
|
||||||
# Check for user data files
|
# Check for user data files
|
||||||
for extension in ["yml", "yaml", "json"]:
|
for extension in ["yml", "yaml", "json"]:
|
||||||
delete_file(Path(f'user_data/users/{name}.{extension}'))
|
delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}')
|
||||||
|
|
||||||
# Check for user image files
|
# Check for user image files
|
||||||
for extension in ["png", "jpg", "jpeg"]:
|
for extension in ["png", "jpg", "jpeg"]:
|
||||||
delete_file(Path(f'user_data/users/{name}.{extension}'))
|
delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}')
|
||||||
|
|
||||||
|
|
||||||
def update_user_menu_after_deletion(idx):
|
def update_user_menu_after_deletion(idx):
|
||||||
|
|
@ -2224,7 +2224,7 @@ def handle_save_template_click(instruction_template_str):
|
||||||
contents = generate_instruction_template_yaml(instruction_template_str)
|
contents = generate_instruction_template_yaml(instruction_template_str)
|
||||||
return [
|
return [
|
||||||
"My Template.yaml",
|
"My Template.yaml",
|
||||||
"user_data/instruction-templates/",
|
str(shared.user_data_dir / 'instruction-templates') + '/',
|
||||||
contents,
|
contents,
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
@ -2233,7 +2233,7 @@ def handle_save_template_click(instruction_template_str):
|
||||||
def handle_delete_template_click(template):
|
def handle_delete_template_click(template):
|
||||||
return [
|
return [
|
||||||
f"{template}.yaml",
|
f"{template}.yaml",
|
||||||
"user_data/instruction-templates/",
|
str(shared.user_data_dir / 'instruction-templates') + '/',
|
||||||
gr.update(visible=False)
|
gr.update(visible=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,8 @@ from modules.text_generation import encode
|
||||||
|
|
||||||
|
|
||||||
def load_past_evaluations():
|
def load_past_evaluations():
|
||||||
if Path('user_data/logs/evaluations.csv').exists():
|
if (shared.user_data_dir / 'logs' / 'evaluations.csv').exists():
|
||||||
df = pd.read_csv(Path('user_data/logs/evaluations.csv'), dtype=str)
|
df = pd.read_csv(shared.user_data_dir / 'logs' / 'evaluations.csv', dtype=str)
|
||||||
df['Perplexity'] = pd.to_numeric(df['Perplexity'])
|
df['Perplexity'] = pd.to_numeric(df['Perplexity'])
|
||||||
return df
|
return df
|
||||||
else:
|
else:
|
||||||
|
|
@ -26,7 +26,7 @@ past_evaluations = load_past_evaluations()
|
||||||
def save_past_evaluations(df):
|
def save_past_evaluations(df):
|
||||||
global past_evaluations
|
global past_evaluations
|
||||||
past_evaluations = df
|
past_evaluations = df
|
||||||
filepath = Path('user_data/logs/evaluations.csv')
|
filepath = shared.user_data_dir / 'logs' / 'evaluations.csv'
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_csv(filepath, index=False)
|
df.to_csv(filepath, index=False)
|
||||||
|
|
||||||
|
|
@ -65,7 +65,7 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
|
||||||
data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
|
data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
|
||||||
text = " ".join(data['sentence'])
|
text = " ".join(data['sentence'])
|
||||||
else:
|
else:
|
||||||
with open(Path(f'user_data/training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
|
with open(shared.user_data_dir / 'training' / 'datasets' / f'{input_dataset}.txt', 'r', encoding='utf-8') as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
@ -38,9 +40,15 @@ def load_extensions():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prefer user extension, fall back to system extension
|
# Prefer user extension, fall back to system extension
|
||||||
user_script_path = Path(f'user_data/extensions/{name}/script.py')
|
user_script_path = shared.user_data_dir / 'extensions' / name / 'script.py'
|
||||||
if user_script_path.exists():
|
if user_script_path.exists():
|
||||||
extension = importlib.import_module(f"user_data.extensions.{name}.script")
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
f"user_ext_{name}",
|
||||||
|
str(user_script_path)
|
||||||
|
)
|
||||||
|
extension = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[spec.name] = extension
|
||||||
|
spec.loader.exec_module(extension)
|
||||||
else:
|
else:
|
||||||
extension = importlib.import_module(f"extensions.{name}.script")
|
extension = importlib.import_module(f"extensions.{name}.script")
|
||||||
|
|
||||||
|
|
@ -53,7 +61,7 @@ def load_extensions():
|
||||||
state[name] = [True, i, extension] # Store extension object
|
state[name] = [True, i, extension] # Store extension object
|
||||||
|
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
extension_location = Path('user_data/extensions') / name if user_script_path.exists() else Path('extensions') / name
|
extension_location = shared.user_data_dir / 'extensions' / name if user_script_path.exists() else Path('extensions') / name
|
||||||
windows_path = str(extension_location).replace('/', '\\')
|
windows_path = str(extension_location).replace('/', '\\')
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
|
f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n"
|
||||||
|
|
|
||||||
|
|
@ -627,10 +627,10 @@ def generate_instruct_html(history, last_message_only=False):
|
||||||
|
|
||||||
def get_character_image_with_cache_buster():
|
def get_character_image_with_cache_buster():
|
||||||
"""Get character image URL with cache busting based on file modification time"""
|
"""Get character image URL with cache busting based on file modification time"""
|
||||||
cache_path = Path("user_data/cache/pfp_character_thumb.png")
|
cache_path = shared.user_data_dir / "cache" / "pfp_character_thumb.png"
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
mtime = int(cache_path.stat().st_mtime)
|
mtime = int(cache_path.stat().st_mtime)
|
||||||
return f'<img src="file/user_data/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">'
|
return f'<img src="file/{shared.user_data_dir}/cache/pfp_character_thumb.png?{mtime}" class="pfp_character">'
|
||||||
|
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
@ -654,8 +654,8 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
|
||||||
|
|
||||||
# Get appropriate image
|
# Get appropriate image
|
||||||
if role == "user":
|
if role == "user":
|
||||||
img = (f'<img src="file/user_data/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
|
img = (f'<img src="file/{shared.user_data_dir}/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
|
||||||
if Path("user_data/cache/pfp_me.png").exists() else '')
|
if (shared.user_data_dir / "cache" / "pfp_me.png").exists() else '')
|
||||||
else:
|
else:
|
||||||
img = img_bot
|
img = img_bot
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -384,7 +384,7 @@ class LlamaServer:
|
||||||
if shared.args.mmproj not in [None, 'None']:
|
if shared.args.mmproj not in [None, 'None']:
|
||||||
path = Path(shared.args.mmproj)
|
path = Path(shared.args.mmproj)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
path = Path('user_data/mmproj') / shared.args.mmproj
|
path = shared.user_data_dir / 'mmproj' / shared.args.mmproj
|
||||||
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
cmd += ["--mmproj", str(path)]
|
cmd += ["--mmproj", str(path)]
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,7 @@ def presets_params():
|
||||||
def load_preset(name, verbose=False):
|
def load_preset(name, verbose=False):
|
||||||
generate_params = default_preset()
|
generate_params = default_preset()
|
||||||
if name not in ['None', None, '']:
|
if name not in ['None', None, '']:
|
||||||
path = Path(f'user_data/presets/{name}.yaml')
|
path = shared.user_data_dir / 'presets' / f'{name}.yaml'
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, 'r') as infile:
|
with open(path, 'r') as infile:
|
||||||
preset = yaml.safe_load(infile)
|
preset = yaml.safe_load(infile)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ def load_prompt(fname):
|
||||||
if not fname:
|
if not fname:
|
||||||
# Create new file
|
# Create new file
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
initial_content = "In this story,"
|
initial_content = "In this story,"
|
||||||
prompt_path.write_text(initial_content, encoding='utf-8')
|
prompt_path.write_text(initial_content, encoding='utf-8')
|
||||||
|
|
@ -18,7 +18,7 @@ def load_prompt(fname):
|
||||||
|
|
||||||
return initial_content
|
return initial_content
|
||||||
|
|
||||||
file_path = Path(f'user_data/logs/notebook/{fname}.txt')
|
file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt'
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,12 @@ from pathlib import Path
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.paths import resolve_user_data_dir
|
||||||
from modules.presets import default_preset
|
from modules.presets import default_preset
|
||||||
|
|
||||||
|
# Resolve user_data directory early (before argparse defaults are set)
|
||||||
|
user_data_dir = resolve_user_data_dir()
|
||||||
|
|
||||||
# Text model variables
|
# Text model variables
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|
@ -42,11 +46,12 @@ parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_
|
||||||
|
|
||||||
# Basic settings
|
# Basic settings
|
||||||
group = parser.add_argument_group('Basic settings')
|
group = parser.add_argument_group('Basic settings')
|
||||||
|
group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user data directory. Default: auto-detected.')
|
||||||
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
|
group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Warning: this is likely not safe for sharing publicly.')
|
||||||
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
group.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||||
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
|
||||||
group.add_argument('--model-dir', type=str, default='user_data/models', help='Path to directory with all the models.')
|
group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.')
|
||||||
group.add_argument('--lora-dir', type=str, default='user_data/loras', help='Path to directory with all the loras.')
|
group.add_argument('--lora-dir', type=str, default=str(user_data_dir / 'loras'), help='Path to directory with all the loras.')
|
||||||
group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
|
group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
|
||||||
group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
|
group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.')
|
||||||
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||||
|
|
@ -56,7 +61,7 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
|
||||||
# Image generation
|
# Image generation
|
||||||
group = parser.add_argument_group('Image model')
|
group = parser.add_argument_group('Image model')
|
||||||
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
|
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
|
||||||
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
|
group.add_argument('--image-model-dir', type=str, default=str(user_data_dir / 'image_models'), help='Path to directory with all the image models.')
|
||||||
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
|
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
|
||||||
group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.')
|
group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.')
|
||||||
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
||||||
|
|
@ -110,7 +115,7 @@ group = parser.add_argument_group('Transformers/Accelerate')
|
||||||
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||||
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
|
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
|
||||||
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
||||||
group.add_argument('--disk-cache-dir', type=str, default='user_data/cache', help='Directory to save the disk cache to. Defaults to "user_data/cache".')
|
group.add_argument('--disk-cache-dir', type=str, default=str(user_data_dir / 'cache'), help='Directory to save the disk cache to.')
|
||||||
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
|
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
|
||||||
group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.')
|
group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.')
|
||||||
|
|
@ -167,7 +172,7 @@ group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4
|
||||||
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
|
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
|
||||||
|
|
||||||
# Handle CMD_FLAGS.txt
|
# Handle CMD_FLAGS.txt
|
||||||
cmd_flags_path = Path(__file__).parent.parent / "user_data" / "CMD_FLAGS.txt"
|
cmd_flags_path = user_data_dir / "CMD_FLAGS.txt"
|
||||||
if cmd_flags_path.exists():
|
if cmd_flags_path.exists():
|
||||||
with cmd_flags_path.open('r', encoding='utf-8') as f:
|
with cmd_flags_path.open('r', encoding='utf-8') as f:
|
||||||
cmd_flags = ' '.join(
|
cmd_flags = ' '.join(
|
||||||
|
|
@ -182,6 +187,7 @@ if cmd_flags_path.exists():
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
user_data_dir = Path(args.user_data_dir) # Update from parsed args (may differ from pre-parse)
|
||||||
original_args = copy.deepcopy(args)
|
original_args = copy.deepcopy(args)
|
||||||
args_defaults = parser.parse_args([])
|
args_defaults = parser.parse_args([])
|
||||||
|
|
||||||
|
|
@ -212,7 +218,7 @@ settings = {
|
||||||
'enable_web_search': False,
|
'enable_web_search': False,
|
||||||
'web_search_pages': 3,
|
'web_search_pages': 3,
|
||||||
'prompt-notebook': '',
|
'prompt-notebook': '',
|
||||||
'preset': 'Qwen3 - Thinking' if Path('user_data/presets/Qwen3 - Thinking.yaml').exists() else None,
|
'preset': 'Qwen3 - Thinking' if (user_data_dir / 'presets/Qwen3 - Thinking.yaml').exists() else None,
|
||||||
'max_new_tokens': 512,
|
'max_new_tokens': 512,
|
||||||
'max_new_tokens_min': 1,
|
'max_new_tokens_min': 1,
|
||||||
'max_new_tokens_max': 4096,
|
'max_new_tokens_max': 4096,
|
||||||
|
|
|
||||||
|
|
@ -107,8 +107,8 @@ def create_ui():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Tab(label='Chat Dataset'):
|
with gr.Tab(label='Chat Dataset'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset = gr.Dropdown(choices=utils.get_chat_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu)
|
dataset = gr.Dropdown(choices=utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
|
format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
|
|
@ -116,14 +116,14 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Tab(label="Text Dataset"):
|
with gr.Tab(label="Text Dataset"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_dataset = gr.Dropdown(choices=utils.get_text_datasets('user_data/training/datasets'), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
|
text_dataset = gr.Dropdown(choices=utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets('user_data/training/datasets')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
|
stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. Values like 256 or 512 help preserve context across chunk boundaries.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
eval_dataset = gr.Dropdown(choices=utils.get_datasets('user_data/training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
eval_dataset = gr.Dropdown(choices=utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu)
|
||||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('user_data/training/datasets', 'json')}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json')}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||||
|
|
||||||
|
|
@ -137,7 +137,7 @@ def create_ui():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
|
models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu)
|
||||||
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('user_data/training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under user_data/training/datasets.', interactive=not mu)
|
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'txt')[1:], value='wikitext', label='Input dataset', info=f'The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under {shared.user_data_dir}/training/datasets.', interactive=not mu)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||||
|
|
@ -224,7 +224,7 @@ def clean_path(base_path: str, path: str):
|
||||||
|
|
||||||
|
|
||||||
def get_instruction_templates():
|
def get_instruction_templates():
|
||||||
path = Path('user_data/instruction-templates')
|
path = shared.user_data_dir / 'instruction-templates'
|
||||||
names = set()
|
names = set()
|
||||||
for ext in ['yaml', 'yml', 'jinja']:
|
for ext in ['yaml', 'yml', 'jinja']:
|
||||||
for f in path.glob(f'*.{ext}'):
|
for f in path.glob(f'*.{ext}'):
|
||||||
|
|
@ -233,8 +233,8 @@ def get_instruction_templates():
|
||||||
|
|
||||||
|
|
||||||
def load_template(name):
|
def load_template(name):
|
||||||
"""Load a Jinja2 template string from user_data/instruction-templates/."""
|
"""Load a Jinja2 template string from {user_data_dir}/instruction-templates/."""
|
||||||
path = Path('user_data/instruction-templates')
|
path = shared.user_data_dir / 'instruction-templates'
|
||||||
for ext in ['jinja', 'yaml', 'yml']:
|
for ext in ['jinja', 'yaml', 'yml']:
|
||||||
filepath = path / f'{name}.{ext}'
|
filepath = path / f'{name}.{ext}'
|
||||||
if filepath.exists():
|
if filepath.exists():
|
||||||
|
|
@ -453,7 +453,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
if has_text_dataset:
|
if has_text_dataset:
|
||||||
train_template["template_type"] = "text_dataset"
|
train_template["template_type"] = "text_dataset"
|
||||||
logger.info("Loading text dataset")
|
logger.info("Loading text dataset")
|
||||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{text_dataset}.json'))
|
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{text_dataset}.json'))
|
||||||
|
|
||||||
if "text" not in data['train'].column_names:
|
if "text" not in data['train'].column_names:
|
||||||
yield "Error: text dataset must have a \"text\" key per row."
|
yield "Error: text dataset must have a \"text\" key per row."
|
||||||
|
|
@ -467,7 +467,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
if eval_dataset == 'None':
|
if eval_dataset == 'None':
|
||||||
eval_data = None
|
eval_data = None
|
||||||
else:
|
else:
|
||||||
eval_raw = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
eval_raw = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json'))
|
||||||
if "text" not in eval_raw['train'].column_names:
|
if "text" not in eval_raw['train'].column_names:
|
||||||
yield "Error: evaluation dataset must have a \"text\" key per row."
|
yield "Error: evaluation dataset must have a \"text\" key per row."
|
||||||
return
|
return
|
||||||
|
|
@ -496,7 +496,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
train_template["template_type"] = "chat_template"
|
train_template["template_type"] = "chat_template"
|
||||||
|
|
||||||
logger.info("Loading JSON dataset with chat template format")
|
logger.info("Loading JSON dataset with chat template format")
|
||||||
data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{dataset}.json'))
|
||||||
|
|
||||||
# Validate the first row
|
# Validate the first row
|
||||||
try:
|
try:
|
||||||
|
|
@ -522,7 +522,7 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
if eval_dataset == 'None':
|
if eval_dataset == 'None':
|
||||||
eval_data = None
|
eval_data = None
|
||||||
else:
|
else:
|
||||||
eval_data = load_dataset("json", data_files=clean_path('user_data/training/datasets', f'{eval_dataset}.json'))
|
eval_data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json'))
|
||||||
eval_data = eval_data['train'].map(
|
eval_data = eval_data['train'].map(
|
||||||
tokenize_conversation,
|
tokenize_conversation,
|
||||||
remove_columns=eval_data['train'].column_names,
|
remove_columns=eval_data['train'].column_names,
|
||||||
|
|
@ -757,11 +757,11 @@ def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en:
|
||||||
decoded_entries.append({"value": decoded_text})
|
decoded_entries.append({"value": decoded_text})
|
||||||
|
|
||||||
# Write the log file
|
# Write the log file
|
||||||
Path('user_data/logs').mkdir(exist_ok=True)
|
(shared.user_data_dir / 'logs').mkdir(exist_ok=True)
|
||||||
with open(Path('user_data/logs/train_dataset_sample.json'), 'w') as json_file:
|
with open(shared.user_data_dir / 'logs' / 'train_dataset_sample.json', 'w') as json_file:
|
||||||
json.dump(decoded_entries, json_file, indent=4)
|
json.dump(decoded_entries, json_file, indent=4)
|
||||||
|
|
||||||
logger.info("Log file 'train_dataset_sample.json' created in the 'user_data/logs' directory.")
|
logger.info(f"Log file 'train_dataset_sample.json' created in the '{shared.user_data_dir}/logs' directory.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create log file due to error: {e}")
|
logger.error(f"Failed to create log file due to error: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ if not shared.args.old_colors:
|
||||||
block_radius='0',
|
block_radius='0',
|
||||||
)
|
)
|
||||||
|
|
||||||
if Path("user_data/notification.mp3").exists():
|
if (shared.user_data_dir / "notification.mp3").exists():
|
||||||
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
audio_notification_js = "document.querySelector('#audio_notification audio')?.play();"
|
||||||
else:
|
else:
|
||||||
audio_notification_js = ""
|
audio_notification_js = ""
|
||||||
|
|
@ -381,7 +381,7 @@ def save_settings(state, preset, extensions_list, show_controls, theme_state, ma
|
||||||
output[_id] = params[param]
|
output[_id] = params[param]
|
||||||
else:
|
else:
|
||||||
# Preserve existing extensions and extension parameters during autosave
|
# Preserve existing extensions and extension parameters during autosave
|
||||||
settings_path = Path('user_data') / 'settings.yaml'
|
settings_path = shared.user_data_dir / 'settings.yaml'
|
||||||
if settings_path.exists():
|
if settings_path.exists():
|
||||||
try:
|
try:
|
||||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||||
|
|
@ -436,7 +436,7 @@ def _perform_debounced_save():
|
||||||
try:
|
try:
|
||||||
if _last_interface_state is not None:
|
if _last_interface_state is not None:
|
||||||
contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False)
|
contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False)
|
||||||
settings_path = Path('user_data') / 'settings.yaml'
|
settings_path = shared.user_data_dir / 'settings.yaml'
|
||||||
settings_path.parent.mkdir(exist_ok=True)
|
settings_path.parent.mkdir(exist_ok=True)
|
||||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(contents)
|
f.write(contents)
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,7 @@ def create_character_settings_ui():
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu)
|
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu)
|
||||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(Path('user_data/cache/pfp_me.png')) if Path('user_data/cache/pfp_me.png').exists() else None, interactive=not mu)
|
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(shared.user_data_dir / 'cache' / 'pfp_me.png') if (shared.user_data_dir / 'cache' / 'pfp_me.png').exists() else None, interactive=not mu)
|
||||||
|
|
||||||
|
|
||||||
def create_chat_settings_ui():
|
def create_chat_settings_ui():
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ def handle_new_prompt():
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
|
|
||||||
# Create the new prompt file
|
# Create the new prompt file
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text("In this story,", encoding='utf-8')
|
prompt_path.write_text("In this story,", encoding='utf-8')
|
||||||
|
|
||||||
|
|
@ -170,15 +170,15 @@ def handle_delete_prompt_confirm_default(prompt_name):
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
(Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
|
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True)
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
|
|
||||||
if available_prompts:
|
if available_prompts:
|
||||||
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
||||||
else:
|
else:
|
||||||
new_value = utils.current_time()
|
new_value = utils.current_time()
|
||||||
Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
|
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True)
|
||||||
(Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
|
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,")
|
||||||
available_prompts = [new_value]
|
available_prompts = [new_value]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -199,8 +199,8 @@ def handle_rename_prompt_click_default(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_default(new_name, current_name):
|
def handle_rename_prompt_confirm_default(new_name, current_name):
|
||||||
old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
if old_path.exists() and not new_path.exists():
|
if old_path.exists() and not new_path.exists():
|
||||||
old_path.rename(new_path)
|
old_path.rename(new_path)
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ def create_ui():
|
||||||
|
|
||||||
# Character saver/deleter
|
# Character saver/deleter
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']:
|
||||||
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info='The character will be saved to your user_data/characters folder with this base filename.')
|
shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info=f'The character will be saved to your {shared.user_data_dir}/characters folder with this base filename.')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
||||||
|
|
@ -41,7 +41,7 @@ def create_ui():
|
||||||
|
|
||||||
# User saver/deleter
|
# User saver/deleter
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']:
|
||||||
shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info='The user profile will be saved to your user_data/users folder with this base filename.')
|
shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info=f'The user profile will be saved to your {shared.user_data_dir}/users folder with this base filename.')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu)
|
||||||
|
|
@ -54,7 +54,7 @@ def create_ui():
|
||||||
|
|
||||||
# Preset saver
|
# Preset saver
|
||||||
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
|
with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']:
|
||||||
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info='The preset will be saved to your user_data/presets folder with this base filename.')
|
shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info=f'The preset will be saved to your {shared.user_data_dir}/presets folder with this base filename.')
|
||||||
shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents')
|
shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button")
|
||||||
|
|
@ -91,7 +91,7 @@ def create_event_handlers():
|
||||||
|
|
||||||
def handle_save_preset_confirm_click(filename, contents):
|
def handle_save_preset_confirm_click(filename, contents):
|
||||||
try:
|
try:
|
||||||
utils.save_file(f"user_data/presets/{filename}.yaml", contents)
|
utils.save_file(str(shared.user_data_dir / "presets" / f"{filename}.yaml"), contents)
|
||||||
available_presets = utils.get_available_presets()
|
available_presets = utils.get_available_presets()
|
||||||
output = gr.update(choices=available_presets, value=filename)
|
output = gr.update(choices=available_presets, value=filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -164,7 +164,7 @@ def handle_save_preset_click(state):
|
||||||
def handle_delete_preset_click(preset):
|
def handle_delete_preset_click(preset):
|
||||||
return [
|
return [
|
||||||
f"{preset}.yaml",
|
f"{preset}.yaml",
|
||||||
"user_data/presets/",
|
str(shared.user_data_dir / "presets") + "/",
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -173,7 +173,7 @@ def handle_save_grammar_click(grammar_string):
|
||||||
return [
|
return [
|
||||||
grammar_string,
|
grammar_string,
|
||||||
"My Fancy Grammar.gbnf",
|
"My Fancy Grammar.gbnf",
|
||||||
"user_data/grammars/",
|
str(shared.user_data_dir / "grammars") + "/",
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -181,7 +181,7 @@ def handle_save_grammar_click(grammar_string):
|
||||||
def handle_delete_grammar_click(grammar_file):
|
def handle_delete_grammar_click(grammar_file):
|
||||||
return [
|
return [
|
||||||
grammar_file,
|
grammar_file,
|
||||||
"user_data/grammars/",
|
str(shared.user_data_dir / "grammars") + "/",
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,7 @@ def save_generated_images(images, state, actual_seed):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||||
folder_path = os.path.join("user_data", "image_outputs", date_str)
|
folder_path = str(shared.user_data_dir / "image_outputs" / date_str)
|
||||||
os.makedirs(folder_path, exist_ok=True)
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
|
|
||||||
metadata = build_generation_metadata(state, actual_seed)
|
metadata = build_generation_metadata(state, actual_seed)
|
||||||
|
|
@ -214,7 +214,7 @@ def get_all_history_images(force_refresh=False):
|
||||||
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
||||||
global _image_cache, _cache_timestamp
|
global _image_cache, _cache_timestamp
|
||||||
|
|
||||||
output_dir = os.path.join("user_data", "image_outputs")
|
output_dir = str(shared.user_data_dir / "image_outputs")
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ def create_ui():
|
||||||
# Multimodal
|
# Multimodal
|
||||||
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
|
shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu)
|
||||||
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
|
|
@ -317,9 +317,9 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
|
||||||
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
|
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if output_folder == Path("user_data/models"):
|
if output_folder == shared.user_data_dir / "models":
|
||||||
output_folder = Path(shared.args.model_dir)
|
output_folder = Path(shared.args.model_dir)
|
||||||
elif output_folder == Path("user_data/loras"):
|
elif output_folder == shared.user_data_dir / "loras":
|
||||||
output_folder = Path(shared.args.lora_dir)
|
output_folder = Path(shared.args.lora_dir)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
|
|
|
||||||
|
|
@ -194,7 +194,7 @@ def handle_new_prompt():
|
||||||
new_name = utils.current_time()
|
new_name = utils.current_time()
|
||||||
|
|
||||||
# Create the new prompt file
|
# Create the new prompt file
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text("In this story,", encoding='utf-8')
|
prompt_path.write_text("In this story,", encoding='utf-8')
|
||||||
|
|
||||||
|
|
@ -205,15 +205,15 @@ def handle_delete_prompt_confirm_notebook(prompt_name):
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
current_index = available_prompts.index(prompt_name) if prompt_name in available_prompts else 0
|
||||||
|
|
||||||
(Path("user_data/logs/notebook") / f"{prompt_name}.txt").unlink(missing_ok=True)
|
(shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True)
|
||||||
available_prompts = utils.get_available_prompts()
|
available_prompts = utils.get_available_prompts()
|
||||||
|
|
||||||
if available_prompts:
|
if available_prompts:
|
||||||
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
new_value = available_prompts[min(current_index, len(available_prompts) - 1)]
|
||||||
else:
|
else:
|
||||||
new_value = utils.current_time()
|
new_value = utils.current_time()
|
||||||
Path("user_data/logs/notebook").mkdir(parents=True, exist_ok=True)
|
(shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True)
|
||||||
(Path("user_data/logs/notebook") / f"{new_value}.txt").write_text("In this story,")
|
(shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,")
|
||||||
available_prompts = [new_value]
|
available_prompts = [new_value]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|
@ -233,8 +233,8 @@ def handle_rename_prompt_click_notebook(current_name):
|
||||||
|
|
||||||
|
|
||||||
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
old_path = Path("user_data/logs/notebook") / f"{current_name}.txt"
|
old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt"
|
||||||
new_path = Path("user_data/logs/notebook") / f"{new_name}.txt"
|
new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt"
|
||||||
|
|
||||||
if old_path.exists() and not new_path.exists():
|
if old_path.exists() and not new_path.exists():
|
||||||
old_path.rename(new_path)
|
old_path.rename(new_path)
|
||||||
|
|
@ -250,7 +250,7 @@ def handle_rename_prompt_confirm_notebook(new_name, current_name):
|
||||||
def autosave_prompt(text, prompt_name):
|
def autosave_prompt(text, prompt_name):
|
||||||
"""Automatically save the text to the selected prompt file"""
|
"""Automatically save the text to the selected prompt file"""
|
||||||
if prompt_name and text.strip():
|
if prompt_name and text.strip():
|
||||||
prompt_path = Path("user_data/logs/notebook") / f"{prompt_name}.txt"
|
prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt"
|
||||||
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
prompt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
prompt_path.write_text(text, encoding='utf-8')
|
prompt_path.write_text(text, encoding='utf-8')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -135,7 +135,7 @@ def get_truncation_length():
|
||||||
|
|
||||||
|
|
||||||
def load_grammar(name):
|
def load_grammar(name):
|
||||||
p = Path(f'user_data/grammars/{name}')
|
p = shared.user_data_dir / 'grammars' / name
|
||||||
if p.exists():
|
if p.exists():
|
||||||
return open(p, 'r', encoding='utf-8').read()
|
return open(p, 'r', encoding='utf-8').read()
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown("## Extensions & flags")
|
gr.Markdown("## Extensions & flags")
|
||||||
shared.gradio['save_settings'] = gr.Button('Save extensions settings to user_data/settings.yaml', elem_classes='refresh-button', interactive=not mu)
|
shared.gradio['save_settings'] = gr.Button(f'Save extensions settings to {shared.user_data_dir}/settings.yaml', elem_classes='refresh-button', interactive=not mu)
|
||||||
shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu)
|
shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
|
@ -54,7 +54,7 @@ def handle_save_settings(state, preset, extensions, show_controls, theme):
|
||||||
return [
|
return [
|
||||||
contents,
|
contents,
|
||||||
"settings.yaml",
|
"settings.yaml",
|
||||||
"user_data/",
|
str(shared.user_data_dir) + "/",
|
||||||
gr.update(visible=True)
|
gr.update(visible=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,16 +15,31 @@ def gradio(*keys):
|
||||||
return [shared.gradio[k] for k in keys]
|
return [shared.gradio[k] for k in keys]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_path_allowed(abs_path_str):
|
||||||
|
"""Check if a path is under the project root or the configured user_data directory."""
|
||||||
|
abs_path = Path(abs_path_str).resolve()
|
||||||
|
root_folder = Path(__file__).resolve().parent.parent
|
||||||
|
user_data_resolved = shared.user_data_dir.resolve()
|
||||||
|
try:
|
||||||
|
abs_path.relative_to(root_folder)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
abs_path.relative_to(user_data_resolved)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def save_file(fname, contents):
|
def save_file(fname, contents):
|
||||||
if fname == '':
|
if fname == '':
|
||||||
logger.error('File name is empty!')
|
logger.error('File name is empty!')
|
||||||
return
|
return
|
||||||
|
|
||||||
root_folder = Path(__file__).resolve().parent.parent
|
|
||||||
abs_path_str = os.path.abspath(fname)
|
abs_path_str = os.path.abspath(fname)
|
||||||
rel_path_str = os.path.relpath(abs_path_str, root_folder)
|
if not _is_path_allowed(abs_path_str):
|
||||||
rel_path = Path(rel_path_str)
|
|
||||||
if rel_path.parts[0] == '..':
|
|
||||||
logger.error(f'Invalid file path: \"{fname}\"')
|
logger.error(f'Invalid file path: \"{fname}\"')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -39,16 +54,14 @@ def delete_file(fname):
|
||||||
logger.error('File name is empty!')
|
logger.error('File name is empty!')
|
||||||
return
|
return
|
||||||
|
|
||||||
root_folder = Path(__file__).resolve().parent.parent
|
|
||||||
abs_path_str = os.path.abspath(fname)
|
abs_path_str = os.path.abspath(fname)
|
||||||
rel_path_str = os.path.relpath(abs_path_str, root_folder)
|
if not _is_path_allowed(abs_path_str):
|
||||||
rel_path = Path(rel_path_str)
|
|
||||||
if rel_path.parts[0] == '..':
|
|
||||||
logger.error(f'Invalid file path: \"{fname}\"')
|
logger.error(f'Invalid file path: \"{fname}\"')
|
||||||
return
|
return
|
||||||
|
|
||||||
if rel_path.exists():
|
p = Path(abs_path_str)
|
||||||
rel_path.unlink()
|
if p.exists():
|
||||||
|
p.unlink()
|
||||||
logger.info(f'Deleted \"{fname}\".')
|
logger.info(f'Deleted \"{fname}\".')
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,7 +88,7 @@ def natural_keys(text):
|
||||||
def check_model_loaded():
|
def check_model_loaded():
|
||||||
if shared.model_name == 'None' or shared.model is None:
|
if shared.model_name == 'None' or shared.model is None:
|
||||||
if len(get_available_models()) == 0:
|
if len(get_available_models()) == 0:
|
||||||
error_msg = "No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your user_data/models folder\n2) Go to the Model tab and select it"
|
error_msg = f"No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your {shared.user_data_dir}/models folder\n2) Go to the Model tab and select it"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
else:
|
else:
|
||||||
|
|
@ -188,7 +201,7 @@ def get_available_ggufs():
|
||||||
|
|
||||||
|
|
||||||
def get_available_mmproj():
|
def get_available_mmproj():
|
||||||
mmproj_dir = Path('user_data/mmproj')
|
mmproj_dir = shared.user_data_dir / 'mmproj'
|
||||||
if not mmproj_dir.exists():
|
if not mmproj_dir.exists():
|
||||||
return ['None']
|
return ['None']
|
||||||
|
|
||||||
|
|
@ -201,11 +214,11 @@ def get_available_mmproj():
|
||||||
|
|
||||||
|
|
||||||
def get_available_presets():
|
def get_available_presets():
|
||||||
return sorted(set((k.stem for k in Path('user_data/presets').glob('*.yaml'))), key=natural_keys)
|
return sorted(set((k.stem for k in (shared.user_data_dir / 'presets').glob('*.yaml'))), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
def get_available_prompts():
|
def get_available_prompts():
|
||||||
notebook_dir = Path('user_data/logs/notebook')
|
notebook_dir = shared.user_data_dir / 'logs' / 'notebook'
|
||||||
notebook_dir.mkdir(parents=True, exist_ok=True)
|
notebook_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
prompt_files = list(notebook_dir.glob('*.txt'))
|
prompt_files = list(notebook_dir.glob('*.txt'))
|
||||||
|
|
@ -221,19 +234,19 @@ def get_available_prompts():
|
||||||
|
|
||||||
|
|
||||||
def get_available_characters():
|
def get_available_characters():
|
||||||
paths = (x for x in Path('user_data/characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
paths = (x for x in (shared.user_data_dir / 'characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
return sorted(set((k.stem for k in paths)), key=natural_keys)
|
return sorted(set((k.stem for k in paths)), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
def get_available_users():
|
def get_available_users():
|
||||||
users_dir = Path('user_data/users')
|
users_dir = shared.user_data_dir / 'users'
|
||||||
users_dir.mkdir(parents=True, exist_ok=True)
|
users_dir.mkdir(parents=True, exist_ok=True)
|
||||||
paths = (x for x in users_dir.iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
paths = (x for x in users_dir.iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
return sorted(set((k.stem for k in paths)), key=natural_keys)
|
return sorted(set((k.stem for k in paths)), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
def get_available_instruction_templates():
|
def get_available_instruction_templates():
|
||||||
path = "user_data/instruction-templates"
|
path = str(shared.user_data_dir / "instruction-templates")
|
||||||
paths = []
|
paths = []
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
|
|
@ -244,13 +257,13 @@ def get_available_instruction_templates():
|
||||||
def get_available_extensions():
|
def get_available_extensions():
|
||||||
# User extensions (higher priority)
|
# User extensions (higher priority)
|
||||||
user_extensions = []
|
user_extensions = []
|
||||||
user_ext_path = Path('user_data/extensions')
|
user_ext_path = shared.user_data_dir / 'extensions'
|
||||||
if user_ext_path.exists():
|
if user_ext_path.exists():
|
||||||
user_exts = map(lambda x: x.parts[2], user_ext_path.glob('*/script.py'))
|
user_exts = map(lambda x: x.parent.name, user_ext_path.glob('*/script.py'))
|
||||||
user_extensions = sorted(set(user_exts), key=natural_keys)
|
user_extensions = sorted(set(user_exts), key=natural_keys)
|
||||||
|
|
||||||
# System extensions (excluding those overridden by user extensions)
|
# System extensions (excluding those overridden by user extensions)
|
||||||
system_exts = map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))
|
system_exts = map(lambda x: x.parent.name, Path('extensions').glob('*/script.py'))
|
||||||
system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys)
|
system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys)
|
||||||
|
|
||||||
return user_extensions + system_extensions
|
return user_extensions + system_extensions
|
||||||
|
|
@ -306,4 +319,4 @@ def get_available_chat_styles():
|
||||||
|
|
||||||
|
|
||||||
def get_available_grammars():
|
def get_available_grammars():
|
||||||
return ['None'] + sorted([item.name for item in list(Path('user_data/grammars').glob('*.gbnf'))], key=natural_keys)
|
return ['None'] + sorted([item.name for item in list((shared.user_data_dir / 'grammars').glob('*.gbnf'))], key=natural_keys)
|
||||||
|
|
|
||||||
16
server.py
16
server.py
|
|
@ -9,7 +9,7 @@ from modules.logging_colors import logger
|
||||||
from modules.prompts import load_prompt
|
from modules.prompts import load_prompt
|
||||||
|
|
||||||
# Set up Gradio temp directory path
|
# Set up Gradio temp directory path
|
||||||
gradio_temp_path = Path('user_data') / 'cache' / 'gradio'
|
gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio'
|
||||||
shutil.rmtree(gradio_temp_path, ignore_errors=True)
|
shutil.rmtree(gradio_temp_path, ignore_errors=True)
|
||||||
gradio_temp_path.mkdir(parents=True, exist_ok=True)
|
gradio_temp_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -94,9 +94,9 @@ def create_interface():
|
||||||
auth = [tuple(cred.split(':')) for cred in auth]
|
auth = [tuple(cred.split(':')) for cred in auth]
|
||||||
|
|
||||||
# Allowed paths
|
# Allowed paths
|
||||||
allowed_paths = ["css", "js", "extensions", "user_data/cache"]
|
allowed_paths = ["css", "js", "extensions", str(shared.user_data_dir / "cache")]
|
||||||
if not shared.args.multi_user:
|
if not shared.args.multi_user:
|
||||||
allowed_paths.append("user_data/image_outputs")
|
allowed_paths.append(str(shared.user_data_dir / "image_outputs"))
|
||||||
|
|
||||||
# Import the extensions and execute their setup() functions
|
# Import the extensions and execute their setup() functions
|
||||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||||
|
|
@ -120,7 +120,7 @@ def create_interface():
|
||||||
|
|
||||||
# Clear existing cache files
|
# Clear existing cache files
|
||||||
for cache_file in ['pfp_character.png', 'pfp_character_thumb.png']:
|
for cache_file in ['pfp_character.png', 'pfp_character_thumb.png']:
|
||||||
cache_path = Path(f"user_data/cache/{cache_file}")
|
cache_path = shared.user_data_dir / "cache" / cache_file
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
cache_path.unlink()
|
cache_path.unlink()
|
||||||
|
|
||||||
|
|
@ -160,8 +160,8 @@ def create_interface():
|
||||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
|
|
||||||
# Audio notification
|
# Audio notification
|
||||||
if Path("user_data/notification.mp3").exists():
|
if (shared.user_data_dir / "notification.mp3").exists():
|
||||||
shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="user_data/notification.mp3", elem_id="audio_notification", visible=False)
|
shared.gradio['audio_notification'] = gr.Audio(interactive=False, value=str(shared.user_data_dir / "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||||
|
|
||||||
# Floating menus for saving/deleting files
|
# Floating menus for saving/deleting files
|
||||||
ui_file_saving.create_ui()
|
ui_file_saving.create_ui()
|
||||||
|
|
@ -244,8 +244,8 @@ if __name__ == "__main__":
|
||||||
settings_file = None
|
settings_file = None
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
settings_file = Path(shared.args.settings)
|
settings_file = Path(shared.args.settings)
|
||||||
elif Path('user_data/settings.yaml').exists():
|
elif (shared.user_data_dir / 'settings.yaml').exists():
|
||||||
settings_file = Path('user_data/settings.yaml')
|
settings_file = shared.user_data_dir / 'settings.yaml'
|
||||||
|
|
||||||
if settings_file is not None:
|
if settings_file is not None:
|
||||||
logger.info(f"Loading settings from \"{settings_file}\"")
|
logger.info(f"Loading settings from \"{settings_file}\"")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue