From 6c2bdda0f02f497ee91f8c13afab7de149823ff0 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 9 Jul 2025 18:38:45 -0700
Subject: [PATCH] Transformers loader: replace
 `use_flash_attention_2`/`use_eager_attention` with a unified
 `attn_implementation`

Closes #7107
---
 modules/loaders.py             | 3 +--
 modules/models_settings.py     | 6 ------
 modules/shared.py              | 3 +--
 modules/transformers_loader.py | 7 +------
 modules/ui.py                  | 3 +--
 modules/ui_model_menu.py       | 3 +--
 6 files changed, 5 insertions(+), 20 deletions(-)

diff --git a/modules/loaders.py b/modules/loaders.py
index 6fbd2198..2b2c6b78 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -40,11 +40,10 @@ loaders_and_params = OrderedDict({
         'load_in_8bit',
         'load_in_4bit',
         'torch_compile',
-        'use_flash_attention_2',
+        'attn_implementation',
         'cpu',
         'disk',
         'use_double_quant',
-        'use_eager_attention',
         'bf16',
         'trust_remote_code',
         'no_use_fast',
diff --git a/modules/models_settings.py b/modules/models_settings.py
index bea5b4d6..a06e594e 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -15,7 +15,6 @@ from modules.logging_colors import logger
 def get_fallback_settings():
     return {
         'bf16': False,
-        'use_eager_attention': False,
         'ctx_size': 2048,
         'rope_freq_base': 0,
         'compress_pos_emb': 1,
@@ -118,14 +117,9 @@ def get_model_metadata(model):
                 if metadata['rope_scaling']['type'] == 'linear':
                     model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
 
-            # For Gemma-2
             if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
                 model_settings['bf16'] = True
 
-            # For Gemma-2
-            if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
-                model_settings['use_eager_attention'] = True
-
     # Try to find the Jinja instruct template
     path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
     if path.exists():
diff --git a/modules/shared.py b/modules/shared.py
index 5333ec4f..57649114 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -61,8 +61,7 @@ group.add_argument('--no-cache', action='store_true', help='Set use_cache to Fal
 group.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.')
 group.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
 group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). Use this if you have any problems related to use_fast.')
-group.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')
-group.add_argument('--use_eager_attention', action='store_true', help='Set attn_implementation= eager while loading the model.')
+group.add_argument('--attn-implementation', type=str, default='sdpa', help='Attention implementation. Valid options: sdpa, eager, flash_attention_2.')
 group.add_argument('--torch-compile', action='store_true', help='Compile the model with torch.compile for improved performance.')
 
 # bitsandbytes 4-bit
diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py
index 905f5c47..ef524b57 100644
--- a/modules/transformers_loader.py
+++ b/modules/transformers_loader.py
@@ -135,20 +135,15 @@ def load_model_HF(model_name):
     params = {
         'low_cpu_mem_usage': True,
         'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
+        'attn_implementation': shared.args.attn_implementation,
     }
 
     if shared.args.trust_remote_code:
         params['trust_remote_code'] = True
 
-    if shared.args.use_flash_attention_2:
-        params['use_flash_attention_2'] = True
-
     if shared.args.force_safetensors:
         params['force_safetensors'] = True
 
-    if shared.args.use_eager_attention:
-        params['attn_implementation'] = 'eager'
-
     config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
 
     if 'chatglm' in model_name.lower():
diff --git a/modules/ui.py b/modules/ui.py
index 0030bb02..9d6777f0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -144,7 +144,7 @@ def list_model_elements():
         'load_in_4bit',
         'torch_compile',
         'flash_attn',
-        'use_flash_attention_2',
+        'attn_implementation',
         'cpu',
         'disk',
         'row_split',
@@ -153,7 +153,6 @@
         'mlock',
         'numa',
         'use_double_quant',
-        'use_eager_attention',
         'bf16',
         'autosplit',
         'enable_tp',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index e09e292e..2018a943 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -44,6 +44,7 @@ def create_ui():
                         shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
                         shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072. ⚠️ Lower this value if you can\'t load the model.')
                         shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
+                        shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
                         shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
                     with gr.Column():
                         shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
@@ -52,7 +53,6 @@ def create_ui():
                         shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
                         shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
                         shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
-                        shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
                         shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
                         shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
                         shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable Tensor Parallelism (TP).')
@@ -96,7 +96,6 @@ def create_ui():
                         shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
                         shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
                         shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
-                        shared.gradio['use_eager_attention'] = gr.Checkbox(label="use_eager_attention", value=shared.args.use_eager_attention, info='Set attn_implementation= eager while loading the model.')
                         shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
                         shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
                         shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
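
Note (illustrative, not part of the patch): the unified option is passed straight through to the `attn_implementation` argument of `AutoModelForCausalLM.from_pretrained()`. A minimal sketch of the resulting call, assuming a placeholder model name:

    # Minimal sketch, not the project's loader code; "org/some-model" is a placeholder.
    import torch
    from transformers import AutoModelForCausalLM

    attn_implementation = "sdpa"  # or "eager", or "flash_attention_2" (requires the flash-attn package)

    model = AutoModelForCausalLM.from_pretrained(
        "org/some-model",                         # placeholder model path
        torch_dtype=torch.float16,                # the loader uses bfloat16 when --bf16 is set
        low_cpu_mem_usage=True,
        attn_implementation=attn_implementation,  # replaces use_flash_attention_2=True / attn_implementation='eager'
    )

With the patch applied, the same choice is exposed on the command line (e.g. `python server.py --attn-implementation flash_attention_2`) and as the `attn-implementation` dropdown in the model menu, defaulting to `sdpa`.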