diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py
index 7f521b8c..5964f012 100644
--- a/modules/transformers_loader.py
+++ b/modules/transformers_loader.py
@@ -109,7 +109,6 @@ def load_model_HF(model_name):
     params = {
         'low_cpu_mem_usage': True,
         'attn_implementation': shared.args.attn_implementation,
-        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
     }
 
     if shared.original_args.trust_remote_code:
@@ -120,6 +119,17 @@ def load_model_HF(model_name):
 
     config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code)
 
+    # Determine torch_dtype: respect --bf16 flag, otherwise autodetect
+    # from model config, but never allow float32.
+    if shared.args.bf16:
+        params['torch_dtype'] = torch.bfloat16
+    else:
+        dtype = getattr(config, 'torch_dtype', None) or getattr(getattr(config, 'text_config', None), 'torch_dtype', None)
+        if dtype in (torch.float16, torch.bfloat16):
+            params['torch_dtype'] = dtype
+        else:
+            params['torch_dtype'] = torch.float16
+
     if 'chatglm' in model_name.lower():
         LoaderClass = AutoModel
     else: