Transformers loader: replace use_flash_attention_2/use_eager_attention with a unified attn_implementation

Closes #7107
oobabooga 2025-07-09 18:38:45 -07:00
parent 511bb31646
commit 6c2bdda0f0
6 changed files with 5 additions and 20 deletions
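For reference, Transformers selects its attention backend through a single attn_implementation argument to from_pretrained, which accepts values such as 'eager', 'sdpa', and 'flash_attention_2'. Below is a minimal sketch of forwarding one unified option instead of separate booleans; the attn_implementation parameter is real Transformers API, while the function name and argument plumbing are illustrative assumptions rather than this commit's exact code.

# Sketch: one string-valued option replaces the old use_flash_attention_2 /
# use_eager_attention booleans. Function name and plumbing are assumptions
# for illustration.
from transformers import AutoModelForCausalLM

def load_transformers_model(model_path, attn_implementation='sdpa'):
    # Valid values include 'eager', 'sdpa', and 'flash_attention_2'
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation=attn_implementation,
    )

With this shape, enabling FlashAttention 2 or forcing eager attention means passing a different string rather than toggling mutually exclusive flags.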


@@ -15,7 +15,6 @@ from modules.logging_colors import logger
 def get_fallback_settings():
     return {
         'bf16': False,
-        'use_eager_attention': False,
         'ctx_size': 2048,
         'rope_freq_base': 0,
         'compress_pos_emb': 1,
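The first hunk drops the boolean default from the fallback settings. The unified replacement is added elsewhere in the commit and is not visible in this excerpt; what follows is a hypothetical sketch of what such a default could look like, with the key name and 'sdpa' value as assumptions.

# Hypothetical sketch only: the 'attn_implementation' key and its default are
# assumptions, not the commit's actual added lines. The remaining keys come
# from the hunk above.
def get_fallback_settings():
    return {
        'bf16': False,
        'attn_implementation': 'sdpa',  # single string instead of per-backend booleans
        'ctx_size': 2048,
        'rope_freq_base': 0,
        'compress_pos_emb': 1,
    }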
@@ -118,14 +117,9 @@ def get_model_metadata(model):
         if metadata['rope_scaling']['type'] == 'linear':
             model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
     # For Gemma-2
     if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
         model_settings['bf16'] = True
-    # For Gemma-2
-    if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
-        model_settings['use_eager_attention'] = True
     # Try to find the Jinja instruct template
     path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
     if path.exists():
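The second hunk removes the metadata-driven autodetection that forced eager attention for Gemma-2, whose attention logit softcapping initially required the eager implementation. The same kind of per-architecture nudge can be expressed as a string derived from config.json; here is a sketch of that inspection, where the helper name and the architecture-to-implementation mapping are assumptions for illustration.

# Sketch of config.json-based detection; the helper name and the mapping to an
# attn_implementation string are assumptions, not the project's code.
import json
from pathlib import Path

def guess_attn_implementation(model_dir: str) -> str:
    path = Path(model_dir) / 'config.json'
    if not path.exists():
        return 'sdpa'

    metadata = json.loads(path.read_text(encoding='utf-8'))
    architectures = metadata.get('architectures') or []
    if 'Gemma2ForCausalLM' in architectures:
        # Gemma-2 originally needed eager attention for its logit softcapping
        return 'eager'

    return 'sdpa'

Whether the new loader keeps any such autodetection or simply defers to the user-selected attn_implementation is not visible in this excerpt.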