text-generation-webui/modules/image_models.py

import time
import torch
import modules.shared as shared
from modules.logging_colors import logger
from modules.torch_utils import get_device
from modules.utils import resolve_model_path


def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
"""
Load a diffusers image generation model.
Args:
model_name: Name of the model directory
dtype: 'bfloat16' or 'float16'
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
"""
    # Import diffusers lazily so it is only required when an image model is loaded
    from diffusers import PipelineQuantizationConfig, ZImagePipeline

    logger.info(f"Loading image model \"{model_name}\"")
    t0 = time.time()

    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    target_dtype = dtype_map.get(dtype, torch.bfloat16)

    model_path = resolve_model_path(model_name, image_model=True)

    try:
        # Define quantization config for 8-bit
        pipeline_quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_8bit",
            quant_kwargs={"load_in_8bit": True},
        )

        # Define quantization config for 4-bit
        # pipeline_quant_config = PipelineQuantizationConfig(
        #     quant_backend="bitsandbytes_4bit",
        #     quant_kwargs={
        #         "load_in_4bit": True,
        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
        #     },
        # )

        pipe = ZImagePipeline.from_pretrained(
            str(model_path),
            quantization_config=pipeline_quant_config,
            torch_dtype=target_dtype,
            low_cpu_mem_usage=True,
        )

        # Move the whole pipeline to the GPU unless CPU offloading is requested
        if not cpu_offload:
            pipe.to(get_device())

        # Set attention backend
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == 'flash_attention_3':
            pipe.transformer.set_attention_backend("_flash_3")
        # sdpa is the default, no action needed

        if compile_model:
            logger.info("Compiling model (first run will be slow)...")
            pipe.transformer.compile()

        # Offload pipeline components to the CPU, moving them to the GPU only when used
        if cpu_offload:
            pipe.enable_model_cpu_offload()

        shared.image_model = pipe
        shared.image_model_name = model_name
        logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
        return pipe
    except Exception as e:
        logger.error(f"Failed to load image model: {str(e)}")
        return None
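

# Example usage (a minimal sketch; "z-image-turbo" is a hypothetical model
# directory name, and the pipeline call below follows the generic diffusers
# text-to-image convention, which may differ for ZImagePipeline):
#
#     pipe = load_image_model("z-image-turbo", dtype="bfloat16", cpu_offload=True)
#     if pipe is not None:
#         image = pipe(prompt="a watercolor painting of a fox").images[0]
#         image.save("fox.png")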


def unload_image_model():
    """Unload the current image model and free VRAM."""
    if shared.image_model is None:
        return

    del shared.image_model
    shared.image_model = None
    shared.image_model_name = 'None'

    from modules.torch_utils import clear_torch_cache
    clear_torch_cache()

    logger.info("Image model unloaded.")