Clean up LoRA loading parameter handling

oobabooga 2026-03-05 15:59:49 -03:00
parent 7a1fa8c9ea
commit 33ff3773a0


@@ -77,9 +77,7 @@ def add_lora_transformers(lora_names):
     if len(lora_names) > 0:
         params = {}
         if not shared.args.cpu:
-            if shared.args.load_in_4bit or shared.args.load_in_8bit:
-                params['peft_type'] = shared.model.dtype
-            else:
+            if not shared.args.load_in_4bit and not shared.args.load_in_8bit:
                 params['dtype'] = shared.model.dtype
         if hasattr(shared.model, "hf_device_map"):
             params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
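
For context, the cleaned-up logic reduces to the sketch below: a dtype is only forwarded for non-CPU, non-quantized loads, and the base model's hf_device_map is re-keyed under the "base_model.model." prefix that PEFT uses when wrapping a model. This is a minimal standalone sketch, assuming webui-style `shared.args` / `shared.model` globals; `build_lora_params` is a hypothetical name used only for illustration, since in the actual file these lines live inside add_lora_transformers() and `params` is passed on to the PEFT loading call outside this hunk.

# Minimal sketch of the cleaned-up parameter handling (hypothetical helper).

def build_lora_params(shared):
    params = {}

    if not shared.args.cpu:
        # 4-bit/8-bit quantized models keep their own compute dtype, so an
        # explicit dtype is only forwarded for full-precision GPU loads.
        if not shared.args.load_in_4bit and not shared.args.load_in_8bit:
            params['dtype'] = shared.model.dtype

    # PEFT wraps the original model under "base_model.model.", so the base
    # model's device map is re-keyed to match the wrapped module paths.
    if hasattr(shared.model, "hf_device_map"):
        params['device_map'] = {
            "base_model.model." + k: v
            for k, v in shared.model.hf_device_map.items()
        }

    return params

Compared with the removed branch, which assigned shared.model.dtype to a 'peft_type' key, the quantized path now simply sets no dtype at all.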