diff --git a/modules/image_models.py b/modules/image_models.py index 6a6c6547..9e2075fd 100644 --- a/modules/image_models.py +++ b/modules/image_models.py @@ -36,6 +36,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl quant_kwargs={"load_in_8bit": True}, ) + # Define quantization config for 4-bit + # pipeline_quant_config = PipelineQuantizationConfig( + # quant_backend="bitsandbytes_4bit", + # quant_kwargs={ + # "load_in_4bit": True, + # "bnb_4bit_quant_type": "nf4", # Or "fp4" for floating point + # "bnb_4bit_compute_dtype": torch.bfloat16, # For faster computation + # "bnb_4bit_use_double_quant": True, # Nested quantization for extra savings + # }, + # ) + pipe = ZImagePipeline.from_pretrained( str(model_path), quantization_config=pipeline_quant_config,