2023-05-06 04:14:56 +02:00
import os
import re
2023-06-14 01:34:35 +02:00
from datetime import datetime
2023-05-06 04:14:56 +02:00
from pathlib import Path
2025-06-08 18:20:21 +02:00
from modules import shared
2023-06-11 17:19:18 +02:00
from modules . logging_colors import logger
2023-07-04 05:03:30 +02:00
# Helper function to get multiple values from shared.gradio
def gradio(*keys):
    """
    Look up one or more components in ``shared.gradio`` by key.

    Accepts either separate key arguments (``gradio('a', 'b')``) or a single
    list/tuple of keys (``gradio(['a', 'b'])``).

    Returns:
        A list of the ``shared.gradio`` entries, in the same order as the keys.
    """
    # isinstance (rather than an exact type check) so list/tuple subclasses,
    # e.g. namedtuples, are also unpacked instead of being used as one key.
    if len(keys) == 1 and isinstance(keys[0], (list, tuple)):
        keys = keys[0]

    return [shared.gradio[k] for k in keys]
2026-03-06 04:26:21 +01:00
def _is_path_allowed(abs_path_str):
    """
    Return True when *abs_path_str* lies inside an allowed location.

    The candidate path and both allowed roots are resolved (symlinks followed)
    before comparing. The allowed roots are the project root (the parent of the
    directory containing this module) and ``shared.user_data_dir``.
    """
    candidate = Path(abs_path_str).resolve()
    allowed_roots = (
        Path(__file__).resolve().parent.parent,
        shared.user_data_dir.resolve(),
    )

    for base in allowed_roots:
        try:
            candidate.relative_to(base)
        except ValueError:
            continue
        return True

    return False
2023-06-11 17:19:18 +02:00
def save_file(fname, contents):
    """
    Write *contents* to *fname* as UTF-8 text.

    Nothing is written and an error is logged when *fname* is the empty string
    or when its absolute path is rejected by ``_is_path_allowed``. An existing
    file at that path is overwritten.
    """
    if fname == '':
        logger.error('File name is empty!')
        return

    target = os.path.abspath(fname)
    if _is_path_allowed(target):
        with open(target, 'w', encoding='utf-8') as handle:
            handle.write(contents)

        logger.info(f'Saved \"{target}\".')
    else:
        logger.error(f'Invalid file path: \"{fname}\"')
2023-06-11 17:19:18 +02:00
def delete_file(fname):
    """
    Delete the file at *fname* if it exists and lies in an allowed location.

    An error is logged (and nothing is deleted) when *fname* is the empty
    string or its absolute path is rejected by ``_is_path_allowed``. A missing
    file is silently ignored; a successful deletion is logged.
    """
    if fname == '':
        logger.error('File name is empty!')
        return

    target = os.path.abspath(fname)
    if not _is_path_allowed(target):
        logger.error(f'Invalid file path: \"{fname}\"')
        return

    target_path = Path(target)
    if not target_path.exists():
        return

    target_path.unlink()
    logger.info(f'Deleted \"{fname}\".')
2023-05-06 04:14:56 +02:00
2023-06-14 01:34:35 +02:00
def current_time():
    """Return the local time formatted as ``YYYY-MM-DD_HHhMMmSSs`` (usable in file names)."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d_%Hh%Mm%Ss')
2023-06-14 01:34:35 +02:00
2023-05-06 04:14:56 +02:00
def atoi(text):
    """
    Convert one chunk of text into a natural-sort key component.

    Chunks made only of decimal digits become ``int`` so they compare
    numerically; anything else is lower-cased so the comparison is
    case-insensitive.

    ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` also
    accepts characters such as superscripts (``'²'``) or circled digits
    (``'①'``) that ``int()`` rejects, which made this raise ``ValueError``.
    """
    return int(text) if text.isdecimal() else text.lower()
2023-05-10 06:34:04 +02:00
# Replace multiple string pairs in a string
def replace_all(text, dic):
    """
    Apply every ``old -> new`` pair from *dic* to *text* and return the result.

    Pairs are applied one after another in the dict's iteration order, so a
    later pair also operates on the output of earlier ones.
    """
    result = text
    for old, new in dic.items():
        result = result.replace(old, new)

    return result
2023-05-06 04:14:56 +02:00
def natural_keys(text):
    """
    Build a sort key that orders embedded numbers numerically ("a2" < "a10").

    *text* is split into alternating non-digit / digit runs, and each run is
    passed through ``atoi``.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))
2025-05-09 16:21:05 +02:00
def check_model_loaded():
    """
    Report whether a model is currently loaded.

    Returns:
        ``(True, None)`` when a model is loaded. Otherwise ``(False, message)``,
        where *message* is also logged as an error; it differs depending on
        whether ``get_available_models()`` finds any models.
    """
    model_missing = shared.model_name == 'None' or shared.model is None
    if not model_missing:
        return True, None

    if get_available_models():
        error_msg = "No model is loaded. Please select one in the Model tab."
    else:
        error_msg = f"No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your {shared.user_data_dir}/models folder\n2) Go to the Model tab and select it"

    logger.error(error_msg)
    return False, error_msg
2025-12-02 18:55:38 +01:00
def resolve_model_path(model_name_or_path, image_model=False):
    """
    Turn a model name or path into a ``Path``.

    If *model_name_or_path* already exists on disk it is returned as a Path
    unchanged. Otherwise it is treated as a name inside
    ``shared.args.image_model_dir`` (when *image_model* is true) or
    ``shared.args.model_dir``.
    """
    candidate = Path(model_name_or_path)
    if candidate.exists():
        return candidate

    base_dir = shared.args.image_model_dir if image_model else shared.args.model_dir
    return Path(f'{base_dir}/{model_name_or_path}')
2023-05-06 04:14:56 +02:00
def _is_first_gguf_part(gguf_path):
    """Return False only for shards 2+ of a multi-part GGUF named like ``name-0000N-of-0000M.gguf``."""
    # Case-insensitive to match get_available_ggufs, which accepts ".GGUF" too.
    match = re.search(r'-(\d+)-of-\d+\.gguf$', os.path.basename(gguf_path), re.IGNORECASE)
    return match is None or match.group(1).lstrip("0") == "1"


def _dir_has_weight_files(dir_path):
    """Return True if *dir_path* directly contains a ``.safetensors`` or ``.pt`` file."""
    return any(
        name.lower().endswith(('.safetensors', '.pt')) and (dir_path / name).is_file()
        for name in os.listdir(dir_path)
    )


def get_available_models():
    """
    Return the list of models that can be selected.

    The result is the GGUF files under ``shared.args.model_dir`` (relative
    paths from ``get_available_ggufs``, keeping only the first shard of each
    multi-part file), followed by the naturally sorted names of top-level
    model directories.

    A top-level directory that contains GGUF files is listed only if it also
    directly contains ``.safetensors``/``.pt`` files. A missing model
    directory yields just the GGUF list instead of raising FileNotFoundError.
    """
    gguf_files = get_available_ggufs()
    model_list = [path for path in gguf_files if _is_first_gguf_part(path)]

    model_dir = Path(shared.args.model_dir)
    if not model_dir.is_dir():
        return model_list

    # First path component of every GGUF file: the top-level entry it lives under.
    dirs_with_gguf = {Path(path).parts[0] for path in gguf_files if Path(path).parts}

    model_dirs = [
        item
        for item in os.listdir(model_dir)
        if (model_dir / item).is_dir()
        and (item not in dirs_with_gguf or _dir_has_weight_files(model_dir / item))
    ]

    return model_list + sorted(model_dirs, key=natural_keys)
2024-02-16 16:43:24 +01:00
2025-12-02 18:55:38 +01:00
def get_available_image_models():
    """
    List image model directories, creating the image model folder if needed.

    Returns:
        Names of the immediate subdirectories of ``shared.args.image_model_dir``
        in natural sort order; plain files are ignored.
    """
    image_dir = Path(shared.args.image_model_dir)
    image_dir.mkdir(parents=True, exist_ok=True)

    subdirs = [name for name in os.listdir(image_dir) if (image_dir / name).is_dir()]
    return sorted(subdirs, key=natural_keys)
2024-02-16 16:43:24 +01:00
def get_available_ggufs():
    """
    Find every ``.gguf`` file below ``shared.args.model_dir``.

    Subdirectories are walked recursively, following directory symlinks, and
    the extension check is case-insensitive.

    Returns:
        Paths relative to the model directory, as strings, in natural sort order.
    """
    root = Path(shared.args.model_dir)
    found = [
        str((Path(dirpath) / name).relative_to(root))
        for dirpath, _dirnames, filenames in os.walk(root, followlinks=True)
        for name in filenames
        if name.lower().endswith(".gguf")
    ]
    return sorted(found, key=natural_keys)
2023-05-06 04:14:56 +02:00
2025-08-10 06:27:25 +02:00
def get_available_mmproj():
    """
    List multimodal projector files in ``<user_data_dir>/mmproj``.

    Returns:
        ``['None']`` followed by the names of ``.gguf``/``.bin`` files
        (suffix compared case-insensitively) in natural sort order, or just
        ``['None']`` when the folder does not exist.
    """
    mmproj_dir = shared.user_data_dir / 'mmproj'
    if not mmproj_dir.exists():
        return ['None']

    names = (
        entry.name
        for entry in mmproj_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in ('.gguf', '.bin')
    )
    return ['None'] + sorted(names, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_available_presets():
    """Return the unique stems of ``*.yaml`` files in ``<user_data_dir>/presets``, naturally sorted."""
    preset_dir = shared.user_data_dir / 'presets'
    stems = {preset.stem for preset in preset_dir.glob('*.yaml')}
    return sorted(stems, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_available_prompts():
    """
    List saved notebook prompts, most recently modified first.

    Prompts are the ``*.txt`` files in ``<user_data_dir>/logs/notebook``, and
    the folder is created if missing. When no prompt exists, a starter file
    named after ``current_time()`` is written so the list is never empty.

    Returns:
        File stems ordered by modification time, newest first.
    """
    notebook_dir = shared.user_data_dir / 'logs' / 'notebook'
    notebook_dir.mkdir(parents=True, exist_ok=True)

    files = list(notebook_dir.glob('*.txt'))
    if len(files) == 0:
        starter = notebook_dir / f"{current_time()}.txt"
        starter.write_text("In this story,", encoding='utf-8')
        files.append(starter)

    files.sort(key=lambda path: path.stat().st_mtime, reverse=True)
    return [path.stem for path in files]
def get_available_characters():
    """
    List character names from ``<user_data_dir>/characters``.

    The folder is created if it does not exist yet (as is done for the users
    folder), so a freshly configured user data directory no longer makes
    ``iterdir()`` raise FileNotFoundError.

    Returns:
        Unique stems of ``.json``/``.yaml``/``.yml`` files, naturally sorted.
    """
    characters_dir = shared.user_data_dir / 'characters'
    characters_dir.mkdir(parents=True, exist_ok=True)

    paths = (x for x in characters_dir.iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
    return sorted(set(k.stem for k in paths), key=natural_keys)
2023-05-06 04:14:56 +02:00
2026-01-15 00:35:08 +01:00
def get_available_users():
    """
    List user profile names from ``<user_data_dir>/users``, creating the folder if needed.

    Returns:
        Unique stems of ``.json``/``.yaml``/``.yml`` files, naturally sorted.
    """
    users_dir = shared.user_data_dir / 'users'
    users_dir.mkdir(parents=True, exist_ok=True)

    allowed_suffixes = ('.json', '.yaml', '.yml')
    stems = {entry.stem for entry in users_dir.iterdir() if entry.suffix in allowed_suffixes}
    return sorted(stems, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_available_instruction_templates():
    """
    List instruction template names from ``<user_data_dir>/instruction-templates``.

    Returns:
        ``['None']`` followed by the unique stems of ``.json``/``.yaml``/``.yml``
        files in natural sort order, or just ``['None']`` if the folder is missing.
    """
    templates_dir = shared.user_data_dir / "instruction-templates"

    stems = set()
    if os.path.exists(templates_dir):
        stems = {t.stem for t in templates_dir.iterdir() if t.suffix in ('.json', '.yaml', '.yml')}

    return ['None'] + sorted(stems, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_available_extensions():
    """
    List extension names, user extensions first.

    An extension is any folder containing a ``script.py``. User extensions are
    read from ``<user_data_dir>/extensions`` and system extensions from
    ``extensions`` (relative to the current working directory). A system
    extension that shares its name with a user extension is omitted. Each
    group is naturally sorted.
    """
    def _extension_names(base):
        # Folder name of every <base>/<name>/script.py
        return {script.parent.name for script in base.glob('*/script.py')}

    user_ext_dir = shared.user_data_dir / 'extensions'
    if user_ext_dir.exists():
        user_extensions = sorted(_extension_names(user_ext_dir), key=natural_keys)
    else:
        user_extensions = []

    system_names = _extension_names(Path('extensions')) - set(user_extensions)
    return user_extensions + sorted(system_names, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_available_loras():
    """
    List LoRA entries in ``shared.args.lora_dir``.

    Returns:
        ``['None']`` followed by the naturally sorted names of all entries,
        skipping names ending in ``.txt``, ``-np``, ``.pt`` or ``.json``.
    """
    skipped_endings = ('.txt', '-np', '.pt', '.json')
    lora_names = [
        entry.name
        for entry in Path(shared.args.lora_dir).glob('*')
        if not entry.name.endswith(skipped_endings)
    ]
    return ['None'] + sorted(lora_names, key=natural_keys)
2023-05-06 04:14:56 +02:00
def get_datasets(path: str, ext: str):
    """
    List dataset names in *path* with extension *ext*.

    For ``ext == "txt"`` the immediate subdirectories are listed too, so a
    folder of .txt files can be chosen as one dataset. The
    ``put-trainer-datasets-here`` placeholder is always skipped.

    Returns:
        ``['None']`` followed by the unique stems, naturally sorted.
    """
    base = Path(path)
    if ext == "txt":
        candidates = list(base.glob('*.txt')) + list(base.glob('*/'))
    else:
        candidates = base.glob(f'*.{ext}')

    names = {k.stem for k in candidates if k.stem != 'put-trainer-datasets-here'}
    return ['None'] + sorted(names, key=natural_keys)
2023-05-08 17:35:03 +02:00
2026-03-05 20:15:16 +01:00
def get_chat_datasets(path: str):
    """
    List JSON datasets in *path* that hold chat conversations.

    A file qualifies when ``_is_chat_dataset`` accepts it (a ``messages`` or
    ShareGPT-style ``conversations`` key on its first record). The
    ``put-trainer-datasets-here`` placeholder is skipped.

    Returns:
        ``['None']`` followed by the unique stems, naturally sorted.
    """
    names = set()
    for candidate in Path(path).glob('*.json'):
        if candidate.stem == 'put-trainer-datasets-here':
            continue
        if _is_chat_dataset(candidate):
            names.add(candidate.stem)

    return ['None'] + sorted(names, key=natural_keys)
def get_text_datasets(path: str):
    """
    List JSON datasets in *path* that hold raw text (``{"text": ...}`` records).

    A file qualifies when ``_is_text_dataset`` accepts it. The
    ``put-trainer-datasets-here`` placeholder is skipped.

    Returns:
        ``['None']`` followed by the unique stems, naturally sorted.
    """
    names = set()
    for candidate in Path(path).glob('*.json'):
        if candidate.stem == 'put-trainer-datasets-here':
            continue
        if _is_text_dataset(candidate):
            names.add(candidate.stem)

    return ['None'] + sorted(names, key=natural_keys)
def _peek_json_keys ( filepath ) :
""" Read the first object in a JSON array file and return its keys. """
import json
try :
with open ( filepath , ' r ' , encoding = ' utf-8 ' ) as f :
data = json . load ( f )
if isinstance ( data , list ) and len ( data ) > 0 and isinstance ( data [ 0 ] , dict ) :
return set ( data [ 0 ] . keys ( ) )
except Exception :
pass
return set ( )
def _is_chat_dataset(filepath):
    """Return True when the first record has a ``messages`` or ``conversations`` key."""
    first_keys = _peek_json_keys(filepath)
    return not first_keys.isdisjoint({'messages', 'conversations'})
def _is_text_dataset(filepath):
    """Return True when the first record of the JSON array has a ``text`` key."""
    return 'text' in _peek_json_keys(filepath)
2023-05-08 17:35:03 +02:00
def get_available_chat_styles():
    """
    List chat style names from ``css/chat_style*.css`` (relative to the working directory).

    The style name is everything after the first ``-`` in the file stem (empty
    if there is none); names are deduplicated and naturally sorted.
    """
    styles = {css.stem.partition('-')[2] for css in Path('css').glob('chat_style*.css')}
    return sorted(styles, key=natural_keys)
2023-09-24 16:08:41 +02:00
def get_available_grammars():
    """Return ``['None']`` followed by the ``*.gbnf`` file names in ``<user_data_dir>/grammars``, naturally sorted."""
    grammar_dir = shared.user_data_dir / 'grammars'
    grammar_files = sorted((g.name for g in grammar_dir.glob('*.gbnf')), key=natural_keys)
    return ['None'] + grammar_files