diff --git a/modules/LoRA.py b/modules/LoRA.py
index a9e9a895..63d764b4 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -77,9 +77,7 @@ def add_lora_transformers(lora_names):
     if len(lora_names) > 0:
         params = {}
         if not shared.args.cpu:
-            if shared.args.load_in_4bit or shared.args.load_in_8bit:
-                params['peft_type'] = shared.model.dtype
-            else:
+            if not shared.args.load_in_4bit and not shared.args.load_in_8bit:
                 params['dtype'] = shared.model.dtype
             if hasattr(shared.model, "hf_device_map"):
                 params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}