text-generation-webui/modules/LoRA.py
oobabooga 2f08dce7b0 Remove ExLlamaV2 backend
- archived upstream: 7dc12af3a8
- replaced by ExLlamaV3, which has much better quantization accuracy
2026-03-05 14:02:13 -08:00

81 lines
2.7 KiB
Python

from pathlib import Path
import modules.shared as shared
from modules.logging_colors import logger
def add_lora_to_model(lora_names):
    """Apply the LoRAs named in ``lora_names`` to the loaded model.

    This is a thin entry point: the list is handed unchanged to
    ``add_lora_transformers``, which does the actual work.
    """
    add_lora_transformers(lora_names)
def get_lora_path(lora_name):
    """Return the location of ``lora_name`` inside ``shared.args.lora_dir``.

    When ``lora_name`` is itself a path that exists on disk, only its last
    component is kept, so the result always points inside the LoRA directory.
    Otherwise ``lora_name`` is appended to the LoRA directory as given.
    """
    candidate = Path(lora_name)
    leaf = candidate.parts[-1] if candidate.exists() else lora_name
    return Path(f"{shared.args.lora_dir}/{leaf}")
def add_lora_transformers(lora_names):
    """Make the LoRA adapters on ``shared.model`` match ``lora_names``.

    The requested names are compared with ``shared.lora_names``:

    * Nothing added and nothing removed: return without touching the model.
    * Only additions, at least one LoRA already applied, and no ``"__merged"``
      adapter present: load the new adapters onto the existing model.
    * Otherwise: unload the adapters if any LoRA was removed, then (when
      ``lora_names`` is not empty) wrap the model with the first LoRA, load
      the remaining ones, and apply the dtype / device handling below.

    Whenever more than one LoRA is requested, ``merge_loras()`` is called
    after the adapters are loaded. On completion ``shared.lora_names`` holds
    ``lora_names``.

    Args:
        lora_names: Names of the LoRAs that should be active; each one is
            resolved to a path with ``get_lora_path``.
    """
    from peft import PeftModel

    from modules.torch_utils import get_device

    prior_set = set(shared.lora_names)
    requested_set = set(lora_names)
    added_set = requested_set - prior_set
    removed_set = prior_set - requested_set

    # If no LoRA needs to be added or removed, exit
    if not added_set and not removed_set:
        return

    # Add a LoRA when another LoRA is already present
    if not removed_set and prior_set and "__merged" not in shared.model.peft_config:
        logger.info(f"Adding the LoRA(s) named {added_set} to the model")
        for lora in added_set:
            shared.model.load_adapter(get_lora_path(lora), lora)

        # merge_loras() reads shared.lora_names, so the new list has to be
        # published before merging; otherwise only the previously loaded
        # adapters would end up in the merged adapter.
        shared.lora_names = lora_names
        if len(lora_names) > 1:
            merge_loras()

        return

    # If any LoRA needs to be removed, start over
    if removed_set:
        shared.model = shared.model.unload()

    # NOTE(review): when "__merged" already exists and nothing was removed,
    # execution reaches this point without unloading, so the existing PEFT
    # model is wrapped again below - confirm that this is intended.
    if len(lora_names) > 0:
        params = {}
        if not shared.args.cpu:
            if not shared.args.load_in_4bit and not shared.args.load_in_8bit:
                params['dtype'] = shared.model.dtype
                if hasattr(shared.model, "hf_device_map"):
                    # Re-key the base model's device map with the
                    # "base_model.model." prefix before passing it on.
                    params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}

        logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
        shared.model = PeftModel.from_pretrained(shared.model, get_lora_path(lora_names[0]), adapter_name=lora_names[0], **params)
        for lora in lora_names[1:]:
            shared.model.load_adapter(get_lora_path(lora), lora)

        if len(lora_names) > 1:
            # Publish the new list only after every adapter loaded
            # successfully, but before merge_loras() consumes it (it would
            # otherwise see the stale, possibly empty, previous list).
            shared.lora_names = lora_names
            merge_loras()

        if not shared.args.load_in_8bit and not shared.args.cpu:
            shared.model.half()
            if not hasattr(shared.model, "hf_device_map"):
                device = get_device()
                if device:
                    shared.model = shared.model.to(device)

    shared.lora_names = lora_names
def merge_loras():
    """Combine the loaded LoRA adapters into one adapter named ``"__merged"``.

    The merge is skipped, with a warning, when the adapters registered in
    ``shared.model.peft_config`` do not all share the same rank ``r``.
    Otherwise every name in ``shared.lora_names`` is combined with weight 1
    and the resulting ``"__merged"`` adapter is made the active one.

    The adapter names are read from ``shared.lora_names`` at call time, so
    that list must describe the adapters to merge when this runs.
    """
    configs = shared.model.peft_config
    distinct_ranks = {configs[name].r for name in configs.keys()}
    if len(distinct_ranks) > 1:
        logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.")
        return

    adapter_names = shared.lora_names
    unit_weights = [1] * len(adapter_names)
    shared.model.add_weighted_adapter(adapter_names, unit_weights, "__merged")
    shared.model.set_adapter("__merged")