diff --git a/modules/models_settings.py b/modules/models_settings.py index 2e3fff9c..7ae68125 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -9,6 +9,8 @@ from modules import chat, loaders, metadata_gguf, shared, ui def get_fallback_settings(): return { + 'bf16': False, + 'use_eager_attention': False, 'wbits': 'None', 'groupsize': 'None', 'desc_act': False, @@ -97,10 +99,18 @@ def get_model_metadata(model): elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']: model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta'] - if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')): + if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')): if metadata['rope_scaling']['type'] == 'linear': model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor'] + # For Gemma-2 + if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16': + model_settings['bf16'] = True + + # For Gemma-2 + if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']: + model_settings['use_eager_attention'] = True + # Read GPTQ metadata for old GPTQ loaders if 'quantization_config' in metadata and metadata['quantization_config'].get('quant_method', '') != 'exl2': if 'bits' in metadata['quantization_config']: @@ -133,7 +143,7 @@ def get_model_metadata(model): for k in ['eos_token', 'bos_token']: if k in metadata: value = metadata[k] - if type(value) is dict: + if isinstance(value, dict): value = value['content'] template = template.replace(k, "'{}'".format(value)) @@ -168,7 +178,7 @@ def infer_loader(model_name, model_settings): path_to_model = Path(f'{shared.args.model_dir}/{model_name}') if not path_to_model.exists(): loader = None - elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0): + elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and isinstance(model_settings['wbits'], int) and model_settings['wbits'] > 0): loader = 'ExLlamav2_HF' elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()): loader = 'AutoAWQ'