Image: Remove the flash_attention_3 option (no idea how to get it working)

This commit is contained in:
oobabooga 2025-12-03 18:40:34 -08:00
parent c93d27add3
commit c357eed4c7
3 changed files with 3 additions and 5 deletions

View file

@ -98,7 +98,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
Args:
model_name: Name of the model directory
dtype: 'bfloat16' or 'float16'
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
attn_backend: 'sdpa' or 'flash_attention_2'
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
quant_method: 'none', 'bnb-8bit', 'bnb-4bit', or torchao options (int8wo, fp4, float8wo)
@ -145,8 +145,6 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
if attn_backend == 'flash_attention_2':
pipe.transformer.set_attention_backend("flash")
elif attn_backend == 'flash_attention_3':
pipe.transformer.set_attention_backend("_flash_3")
# sdpa is the default, no action needed
if compile_model: