diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py
index e4072125..7866f448 100644
--- a/modules/transformers_loader.py
+++ b/modules/transformers_loader.py
@@ -137,6 +137,7 @@ def load_model_HF(model_name):
     params = {
         'low_cpu_mem_usage': True,
         'attn_implementation': shared.args.attn_implementation,
+        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
     }

     if shared.args.trust_remote_code: