# modules/image_models.py
import time

import torch

import modules.shared as shared
from modules.logging_colors import logger
from modules.torch_utils import get_device
from modules.utils import resolve_model_path


def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
    """
    Load a diffusers image generation model.

    Args:
        model_name: Name of the model directory
        dtype: 'bfloat16' or 'float16'
        attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
        cpu_offload: Enable CPU offloading for low VRAM
        compile_model: Compile the model for faster inference (slow first run)
    """
    from diffusers import ZImagePipeline

    logger.info(f"Loading image model \"{model_name}\"")
    t0 = time.time()

    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    target_dtype = dtype_map.get(dtype, torch.bfloat16)

    model_path = resolve_model_path(model_name, image_model=True)

    try:
        pipe = ZImagePipeline.from_pretrained(
            str(model_path),
            torch_dtype=target_dtype,
            low_cpu_mem_usage=False,
        )

        if not cpu_offload:
            pipe.to(get_device())

        # Set attention backend
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == 'flash_attention_3':
            pipe.transformer.set_attention_backend("_flash_3")
        # sdpa is the default, no action needed

        if compile_model:
            logger.info("Compiling model (first run will be slow)...")
            pipe.transformer.compile()

        if cpu_offload:
            pipe.enable_model_cpu_offload()

        shared.image_model = pipe
        shared.image_model_name = model_name
        logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
        return pipe
    except Exception as e:
        logger.error(f"Failed to load image model: {str(e)}")
        return None


def unload_image_model():
    """Unload the current image model and free VRAM."""
    if shared.image_model is None:
        return

    del shared.image_model
    shared.image_model = None
    shared.image_model_name = 'None'

    # Clear CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logger.info("Image model unloaded.")
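

# Usage sketch (comments only, not executed on import). This is a minimal
# example under stated assumptions: the directory name "z-image-turbo" is a
# placeholder for whatever model folder resolve_model_path() can find, and the
# pipeline call follows the standard diffusers text-to-image interface
# (callable pipeline returning an object with an .images list).
#
#     from modules.image_models import load_image_model, unload_image_model
#
#     pipe = load_image_model("z-image-turbo", dtype="bfloat16", cpu_offload=True)
#     if pipe is not None:
#         image = pipe(prompt="a lighthouse at dusk").images[0]
#         image.save("output.png")
#     unload_image_model()  # release VRAM when done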