Transformers loader: replace use_flash_attention_2/use_eager_attention with a unified attn_implementation

Closes #7107
oobabooga 2025-07-09 18:38:45 -07:00
parent 511bb31646
commit 6c2bdda0f0
6 changed files with 5 additions and 20 deletions
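In short, two mutually exclusive boolean switches are collapsed into one string option that maps directly onto the attn_implementation argument of transformers' from_pretrained(). A minimal sketch of the mapping, assuming the standard transformers API; the model path is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM

# Before: two separate flags chose the attention backend indirectly.
#   --use_flash_attention_2  ->  use_flash_attention_2=True
#   --use_eager_attention    ->  attn_implementation='eager'
# After: one option is forwarded verbatim.
#   --attn-implementation {sdpa, eager, flash_attention_2}

model = AutoModelForCausalLM.from_pretrained(
    "path/to/model",                 # illustrative local path
    torch_dtype=torch.float16,
    attn_implementation="sdpa",      # or "eager" / "flash_attention_2"
)
```

sdpa stays the default, matching the new command-line flag below.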

@@ -40,11 +40,10 @@ loaders_and_params = OrderedDict({
         'load_in_8bit',
         'load_in_4bit',
         'torch_compile',
-        'use_flash_attention_2',
+        'attn_implementation',
         'cpu',
         'disk',
         'use_double_quant',
-        'use_eager_attention',
         'bf16',
         'trust_remote_code',
         'no_use_fast',

@@ -15,7 +15,6 @@ from modules.logging_colors import logger
 def get_fallback_settings():
     return {
         'bf16': False,
-        'use_eager_attention': False,
         'ctx_size': 2048,
         'rope_freq_base': 0,
         'compress_pos_emb': 1,
@@ -118,14 +117,9 @@ def get_model_metadata(model):
         if metadata['rope_scaling']['type'] == 'linear':
             model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']

     # For Gemma-2
     if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
         model_settings['bf16'] = True

-    # For Gemma-2
-    if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
-        model_settings['use_eager_attention'] = True
-
     # Try to find the Jinja instruct template
     path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
     if path.exists():

@@ -61,8 +61,7 @@ group.add_argument('--no-cache', action='store_true', help='Set use_cache to Fal
 group.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.')
 group.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
 group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). Use this if you have any problems related to use_fast.')
-group.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')
-group.add_argument('--use_eager_attention', action='store_true', help='Set attn_implementation= eager while loading the model.')
+group.add_argument('--attn-implementation', type=str, default='sdpa', help='Attention implementation. Valid options: sdpa, eager, flash_attention_2.')
 group.add_argument('--torch-compile', action='store_true', help='Compile the model with torch.compile for improved performance.')

 # bitsandbytes 4-bit
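The new flag is a plain string option rather than two store_true switches. A standalone sketch of its shape, outside the project's real parser (the argument group title here is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Transformers/Accelerate')  # illustrative group title
# Unified option replacing --use_flash_attention_2 and --use_eager_attention.
group.add_argument('--attn-implementation', type=str, default='sdpa',
                   help='Attention implementation. Valid options: sdpa, eager, flash_attention_2.')

args = parser.parse_args(['--attn-implementation', 'flash_attention_2'])
print(args.attn_implementation)  # argparse turns the dash into an underscore
```

That dash-to-underscore mapping is why the loader below can read shared.args.attn_implementation.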

@@ -135,20 +135,15 @@ def load_model_HF(model_name):
     params = {
         'low_cpu_mem_usage': True,
         'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
+        'attn_implementation': shared.args.attn_implementation,
     }

     if shared.args.trust_remote_code:
         params['trust_remote_code'] = True

-    if shared.args.use_flash_attention_2:
-        params['use_flash_attention_2'] = True
-
     if shared.args.force_safetensors:
         params['force_safetensors'] = True

-    if shared.args.use_eager_attention:
-        params['attn_implementation'] = 'eager'
-
     config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

     if 'chatglm' in model_name.lower():
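With the booleans gone, the value simply rides along in the loader's kwargs. A minimal sketch of that flow, using a stand-in args object instead of shared.args and an illustrative model path:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM


class Args:  # stand-in for shared.args
    bf16 = False
    trust_remote_code = False
    attn_implementation = 'sdpa'


args = Args()
path_to_model = 'path/to/model'  # illustrative

params = {
    'low_cpu_mem_usage': True,
    'torch_dtype': torch.bfloat16 if args.bf16 else torch.float16,
    'attn_implementation': args.attn_implementation,  # forwarded unchanged
}
if args.trust_remote_code:
    params['trust_remote_code'] = True

config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=args.trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(path_to_model, **params)
```

The special-case branch that forced attn_implementation='eager' is no longer needed; whatever the user picked is passed through.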

@@ -144,7 +144,7 @@ def list_model_elements():
         'load_in_4bit',
         'torch_compile',
         'flash_attn',
-        'use_flash_attention_2',
+        'attn_implementation',
         'cpu',
         'disk',
         'row_split',
@@ -153,7 +153,6 @@ def list_model_elements():
         'mlock',
         'numa',
         'use_double_quant',
-        'use_eager_attention',
         'bf16',
         'autosplit',
         'enable_tp',

@@ -44,6 +44,7 @@ def create_ui():
                 shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
                 shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072. ⚠️ Lower this value if you can\'t load the model.')
                 shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
+                shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
                 shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
             with gr.Column():
                 shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
@@ -52,7 +53,6 @@ def create_ui():
                 shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
                 shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
                 shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
-                shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
                 shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
                 shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
                 shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable Tensor Parallelism (TP).')
@@ -96,7+96,6 @@ def create_ui():
                 shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
                 shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
                 shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
-                shared.gradio['use_eager_attention'] = gr.Checkbox(label="use_eager_attention", value=shared.args.use_eager_attention, info='Set attn_implementation= eager while loading the model.')
                 shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
                 shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
                 shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
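On the UI side the setting is just a dropdown whose value feeds that same argument. A stripped-down Gradio sketch, separate from the project's create_ui (the callback is illustrative):

```python
import gradio as gr


def describe(attn_implementation):
    # In the real loader this value ends up in from_pretrained(**params).
    return f"Model would load with attn_implementation={attn_implementation!r}"


with gr.Blocks() as demo:
    attn = gr.Dropdown(label="attn-implementation",
                       choices=['sdpa', 'eager', 'flash_attention_2'],
                       value='sdpa', info='Attention implementation.')
    out = gr.Textbox(label="result")
    attn.change(describe, inputs=attn, outputs=out)

if __name__ == '__main__':
    demo.launch()
```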