Changing transformers_loader.py to Match User Expectations for --bf16 and Flash Attention 2 (#7217)

Author: stevenxdavis
Date: 2025-09-17 14:39:04 -05:00 (committed by GitHub)
Parent: 9e9ab39892
Commit: dd6d2223a5


@@ -137,6 +137,7 @@ def load_model_HF(model_name):
     params = {
         'low_cpu_mem_usage': True,
         'attn_implementation': shared.args.attn_implementation,
+        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
     }
     if shared.args.trust_remote_code:
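
For context, keyword arguments such as `low_cpu_mem_usage`, `attn_implementation`, and `torch_dtype` in this dict are ultimately forwarded to Hugging Face's `from_pretrained`. The sketch below illustrates the dtype selection added by this commit in isolation; the standalone `load_model` wrapper, its argument names, and its defaults are illustrative assumptions, not the repository's actual loader code.

```python
# Minimal sketch (not the repo's code): shows how a bf16 flag would select the
# torch dtype passed to from_pretrained, alongside the attention implementation
# ("flash_attention_2", "sdpa", or "eager").
import torch
from transformers import AutoModelForCausalLM

def load_model(model_path, bf16=False, attn_implementation="sdpa", trust_remote_code=False):
    params = {
        'low_cpu_mem_usage': True,
        'attn_implementation': attn_implementation,
        # --bf16 selects bfloat16; otherwise fall back to float16, as in the diff above
        'torch_dtype': torch.bfloat16 if bf16 else torch.float16,
    }
    if trust_remote_code:
        params['trust_remote_code'] = True
    return AutoModelForCausalLM.from_pretrained(model_path, **params)
```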