Don't pass torch_dtype to transformers loader, let it be autodetected

oobabooga 2025-08-05 11:35:53 -07:00
parent 3039aeffeb
commit 3b28dc1821


@@ -136,7 +136,6 @@ def load_model_HF(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     params = {
         'low_cpu_mem_usage': True,
-        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
         'attn_implementation': shared.args.attn_implementation,
     }
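
For context, a minimal sketch of the behavior this change relies on (not the project's actual loader code): when torch_dtype is omitted or passed as "auto", transformers' from_pretrained infers the dtype from the checkpoint's own config instead of a value forced by the caller. The model id below is a placeholder used only for illustration, and the exact default behavior depends on the installed transformers version.

from transformers import AutoModelForCausalLM

# Placeholder model id, for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype="auto",      # infer the dtype stored in the checkpoint config
    low_cpu_mem_usage=True,
)
print(model.dtype)           # e.g. torch.float16 for an fp16 checkpoint

With the hard-coded bfloat16/float16 choice removed, the loaded precision follows whatever the checkpoint declares rather than a value derived from the bf16 flag.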