text-generation-webui/modules/models.py


import sys
import time

import modules.shared as shared
from modules.logging_colors import logger
from modules.models_settings import get_model_metadata
from modules.utils import resolve_model_path

last_generation_time = time.time()

def load_model(model_name, loader=None):
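    """Load `model_name` with the requested loader (or the one detected from the
    model's metadata) and return a (model, tokenizer) pair, or (None, None) if
    loading fails."""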
    logger.info(f"Loading \"{model_name}\"")
    t0 = time.time()

    shared.is_seq2seq = False
    shared.model_name = model_name
    load_func_map = {
        'llama.cpp': llama_cpp_server_loader,
        'Transformers': transformers_loader,
        'ExLlamav3_HF': ExLlamav3_HF_loader,
        'ExLlamav3': ExLlamav3_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'TensorRT-LLM': TensorRT_LLM_loader,
        'ktransformers': ktransformers_loader,
    }

    metadata = get_model_metadata(model_name)
    if loader is None:
        if shared.args.loader is not None:
            loader = shared.args.loader
        else:
            loader = metadata['loader']
            if loader is None:
                logger.error('The path to the model does not exist. Exiting.')
                raise ValueError

    if loader != 'llama.cpp' and 'sampler_hijack' not in sys.modules:
        from modules import sampler_hijack
        sampler_hijack.hijack_samplers()

    shared.args.loader = loader
    output = load_func_map[loader](model_name)
    if type(output) is tuple:
        model, tokenizer = output
    else:
        model = output
        if model is not None:
            from modules.transformers_loader import load_tokenizer
            tokenizer = load_tokenizer(model_name)

    if model is None:
        return None, None

    shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
        shared.settings['truncation_length'] = shared.args.ctx_size

    shared.is_multimodal = False
    if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'):
        shared.is_multimodal = model.is_multimodal()

    logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
    logger.info(f"LOADER: \"{loader}\"")
    logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
    logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")

    return model, tokenizer


def llama_cpp_server_loader(model_name):
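    """Load a GGUF model through the llama.cpp server backend. `model_name` may
    point to a single .gguf file or to a directory, in which case the first
    .gguf file (sorted by name) is used. Returns the server object as both
    model and tokenizer."""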
    from modules.llama_cpp_server import LlamaServer

    path = resolve_model_path(model_name)
    if path.is_file():
        model_file = path
    else:
        gguf_files = sorted(path.glob('*.gguf'))
        if not gguf_files:
            logger.error(f"No .gguf models found in the directory: {path}")
            return None, None

        model_file = gguf_files[0]

    try:
        model = LlamaServer(model_file)
        return model, model
    except Exception as e:
        logger.error(f"Error loading the model with llama.cpp: {str(e)}")
        return None, None


def transformers_loader(model_name):
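    """Load the model through the Hugging Face Transformers backend."""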
    from modules.transformers_loader import load_model_HF
    return load_model_HF(model_name)


def ExLlamav3_HF_loader(model_name):
    from modules.exllamav3_hf import Exllamav3HF
    return Exllamav3HF.from_pretrained(model_name)


def ExLlamav3_loader(model_name):
    from modules.exllamav3 import Exllamav3Model
    model, tokenizer = Exllamav3Model.from_pretrained(model_name)
    return model, tokenizer


def ExLlamav2_HF_loader(model_name):
    from modules.exllamav2_hf import Exllamav2HF
    return Exllamav2HF.from_pretrained(model_name)


def ExLlamav2_loader(model_name):
    from modules.exllamav2 import Exllamav2Model
    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
    return model, tokenizer


def TensorRT_LLM_loader(model_name):
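    """Load the model with TensorRT-LLM. The `tensorrt_llm` package has to be
    installed separately (see the TensorRT-LLM GitHub repository)."""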
    try:
        from modules.tensorrt_llm import TensorRTLLMModel
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")

    model = TensorRTLLMModel.from_pretrained(model_name)
    return model


def ktransformers_loader(model_name):
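    """Load the model through Transformers with the KTransformers patches
    applied; importing `ktransformers` is what activates them."""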
    try:
        import ktransformers  # importing it activates the patches/accelerators
    except ModuleNotFoundError:
        logger.error("KTransformers is not installed: pip install ktransformers")
        raise

    from modules.transformers_loader import load_model_HF
    return load_model_HF(model_name)


def unload_model(keep_model_name=False):
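    """Unload the current model and tokenizer and free the associated memory.
    With `keep_model_name=True`, `shared.model_name` is preserved so the model
    can be reloaded later."""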
    if shared.model is None:
        return

    model_class_name = shared.model.__class__.__name__
    is_llamacpp = (model_class_name == 'LlamaServer')

    if model_class_name in ['Exllamav3Model', 'Exllamav3HF']:
        shared.model.unload()
    elif model_class_name in ['Exllamav2Model', 'Exllamav2HF'] and hasattr(shared.model, 'unload'):
        shared.model.unload()

    shared.model = shared.tokenizer = None
    shared.lora_names = []
    shared.model_dirty_from_training = False

    if not is_llamacpp:
        from modules.torch_utils import clear_torch_cache
        clear_torch_cache()

    if not keep_model_name:
        shared.model_name = 'None'


def reload_model():
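    """Unload the current model and load it again under the same name."""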
    unload_model()
    shared.model, shared.tokenizer = load_model(shared.model_name)


def unload_model_if_idle():
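    """Background loop that unloads the model after `shared.args.idle_timeout`
    minutes without a generation. Runs once per minute and holds
    `shared.generation_lock` while checking so an active generation is never
    interrupted."""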
    global last_generation_time

    logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")

    while True:
        shared.generation_lock.acquire()
        try:
            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
                if shared.model is not None:
                    logger.info("Unloading the model for inactivity.")
                    unload_model(keep_model_name=True)
        finally:
            shared.generation_lock.release()

        time.sleep(60)