diff --git a/modules/image_models.py b/modules/image_models.py
index 21612f61..6a6c6547 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -19,7 +19,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
     """
-    from diffusers import ZImagePipeline
+    from diffusers import PipelineQuantizationConfig, ZImagePipeline

     logger.info(f"Loading image model \"{model_name}\"")
     t0 = time.time()
@@ -30,10 +30,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     model_path = resolve_model_path(model_name, image_model=True)

     try:
+        # Define quantization config for 8-bit
+        pipeline_quant_config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+        )
+
         pipe = ZImagePipeline.from_pretrained(
             str(model_path),
+            quantization_config=pipeline_quant_config,
             torch_dtype=target_dtype,
-            low_cpu_mem_usage=False,
+            low_cpu_mem_usage=True,
         )

         if not cpu_offload:
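
For reference, a minimal standalone sketch of the loading path this diff introduces. It assumes a diffusers release that exposes `PipelineQuantizationConfig` and `ZImagePipeline` at the top level (as the new import does) and that `bitsandbytes` is installed with a CUDA device available; the model id `"Tongyi-MAI/Z-Image-Turbo"` and the prompt are placeholders, since the module itself passes a local path resolved by `resolve_model_path()`:

```python
import torch
from diffusers import PipelineQuantizationConfig, ZImagePipeline

# Quantize the pipeline's model weights to 8-bit via the bitsandbytes
# backend. Requires the `bitsandbytes` package and a CUDA device.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
)

# "Tongyi-MAI/Z-Image-Turbo" is a placeholder model id; the patched
# module loads from str(model_path) returned by resolve_model_path().
pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to("cuda")

image = pipe(prompt="a lighthouse at dusk, oil painting").images[0]
image.save("out.png")
```

Note that `low_cpu_mem_usage=True` matters here: it streams weights into the quantized modules shard by shard instead of materializing a full-precision copy in RAM first, which is the point of flipping it from `False` alongside the 8-bit config.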