mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
Don't pass torch_dtype to transformers, autodetect from model config
This commit is contained in:
parent
4073164be0
commit
a32ce254f2
1 changed file with 11 additions and 1 deletion
|
|
@@ -109,7 +109,6 @@ def load_model_HF(model_name):
|
|||
params = {
|
||||
'low_cpu_mem_usage': True,
|
||||
'attn_implementation': shared.args.attn_implementation,
|
||||
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
|
||||
}
|
||||
|
||||
if shared.original_args.trust_remote_code:
|
||||
|
|
@@ -120,6 +119,17 @@ def load_model_HF(model_name):
|
|||
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code)
|
||||
|
||||
# Determine torch_dtype: respect --bf16 flag, otherwise autodetect
|
||||
# from model config, but never allow float32.
|
||||
if shared.args.bf16:
|
||||
params['torch_dtype'] = torch.bfloat16
|
||||
else:
|
||||
dtype = getattr(config, 'torch_dtype', None) or getattr(getattr(config, 'text_config', None), 'torch_dtype', None)
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
params['torch_dtype'] = dtype
|
||||
else:
|
||||
params['torch_dtype'] = torch.float16
|
||||
|
||||
if 'chatglm' in model_name.lower():
|
||||
LoaderClass = AutoModel
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue