mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 07:03:37 +00:00
Image: Several fixes
This commit is contained in:
parent
8eac99599a
commit
afa29b9554
4 changed files with 36 additions and 36 deletions
|
|
@ -141,16 +141,24 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
|
|||
if not cpu_offload:
|
||||
pipe.to(get_device())
|
||||
|
||||
# Set attention backend (if supported by the pipeline)
|
||||
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
|
||||
if attn_backend == 'flash_attention_2':
|
||||
pipe.transformer.set_attention_backend("flash")
|
||||
# sdpa is the default, no action needed
|
||||
modules = ["transformer", "unet"]
|
||||
|
||||
# Set attention backend
|
||||
if attn_backend == 'flash_attention_2':
|
||||
for name in modules:
|
||||
mod = getattr(pipe, name, None)
|
||||
if hasattr(mod, "set_attention_backend"):
|
||||
mod.set_attention_backend("flash")
|
||||
break
|
||||
|
||||
# Compile model
|
||||
if compile_model:
|
||||
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
|
||||
logger.info("Compiling model (first run will be slow)...")
|
||||
pipe.transformer.compile()
|
||||
for name in modules:
|
||||
mod = getattr(pipe, name, None)
|
||||
if hasattr(mod, "compile"):
|
||||
logger.info("Compiling model (first run will be slow)...")
|
||||
mod.compile()
|
||||
break
|
||||
|
||||
if cpu_offload:
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue