Don't pass torch_dtype to transformers, autodetect from model config

This commit is contained in:
oobabooga 2026-04-01 20:28:44 -07:00
parent 4073164be0
commit a32ce254f2

View file

@@ -109,7 +109,6 @@ def load_model_HF(model_name):
params = {
'low_cpu_mem_usage': True,
'attn_implementation': shared.args.attn_implementation,
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
}
if shared.original_args.trust_remote_code:
@@ -120,6 +119,17 @@ def load_model_HF(model_name):
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code)
# Determine torch_dtype: respect --bf16 flag, otherwise autodetect
# from model config, but never allow float32.
if shared.args.bf16:
params['torch_dtype'] = torch.bfloat16
else:
dtype = getattr(config, 'torch_dtype', None) or getattr(getattr(config, 'text_config', None), 'torch_dtype', None)
if dtype in (torch.float16, torch.bfloat16):
params['torch_dtype'] = dtype
else:
params['torch_dtype'] = torch.float16
if 'chatglm' in model_name.lower():
LoaderClass = AutoModel
else: