diff --git a/extensions/Training_PRO/script.py b/extensions/Training_PRO/script.py
index cb11a8df..e2f90f17 100644
--- a/extensions/Training_PRO/script.py
+++ b/extensions/Training_PRO/script.py
@@ -823,7 +823,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         lora_model = get_peft_model(shared.model, config)
         if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
             logger.info("Loading existing LoRA data...")
-            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
+            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
             set_peft_model_state_dict(lora_model, state_dict_peft)

             print(f" + Continue Training on {RED}{lora_file_path}/adapter_model.bin{RESET}")
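The change passes `weights_only=True` to `torch.load` so the existing LoRA adapter is deserialized as plain tensor data rather than through a full unpickle, which avoids executing arbitrary code from a tampered checkpoint file. Below is a minimal sketch of the same pattern outside the extension; the `nn.Linear` model and the local `adapter_model.bin` path are hypothetical stand-ins, not part of the patch.

```python
# Minimal sketch of the weights_only=True loading pattern (illustrative only).
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "adapter_model.bin")  # hypothetical checkpoint path

# weights_only=True restricts unpickling to tensors and basic containers,
# so a malicious checkpoint cannot run arbitrary code during loading.
state_dict = torch.load("adapter_model.bin", weights_only=True)
model.load_state_dict(state_dict)
```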