Image: Several fixes

This commit is contained in:
oobabooga 2025-12-05 05:53:22 -08:00
parent 8eac99599a
commit afa29b9554
4 changed files with 36 additions and 36 deletions

View file

@ -141,16 +141,24 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
if not cpu_offload:
pipe.to(get_device())
# Set attention backend (if supported by the pipeline)
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
if attn_backend == 'flash_attention_2':
pipe.transformer.set_attention_backend("flash")
# sdpa is the default, no action needed
modules = ["transformer", "unet"]
# Set attention backend
if attn_backend == 'flash_attention_2':
for name in modules:
mod = getattr(pipe, name, None)
if hasattr(mod, "set_attention_backend"):
mod.set_attention_backend("flash")
break
# Compile model
if compile_model:
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
logger.info("Compiling model (first run will be slow)...")
pipe.transformer.compile()
for name in modules:
mod = getattr(pipe, name, None)
if hasattr(mod, "compile"):
logger.info("Compiling model (first run will be slow)...")
mod.compile()
break
if cpu_offload:
pipe.enable_model_cpu_offload()