Add the code for 4-bit quantization

oobabooga 2025-11-27 18:29:32 -08:00
parent 742db85de0
commit cecb172d2c

@@ -36,6 +36,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
             quant_kwargs={"load_in_8bit": True},
         )
+        # Define quantization config for 4-bit
+        # pipeline_quant_config = PipelineQuantizationConfig(
+        #     quant_backend="bitsandbytes_4bit",
+        #     quant_kwargs={
+        #         "load_in_4bit": True,
+        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
+        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
+        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
+        #     },
+        # )
+
     pipe = ZImagePipeline.from_pretrained(
         str(model_path),
         quantization_config=pipeline_quant_config,
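For reference, the commented-out block above would look roughly like the following when enabled. This is a minimal sketch, not the repo's final code: it assumes a recent diffusers build that ships ZImagePipeline and PipelineQuantizationConfig, with bitsandbytes installed; the model id "Tongyi-MAI/Z-Image-Turbo" is a placeholder standing in for the repo's local model_path.

# Minimal sketch of the 4-bit path from this commit, uncommented.
# Assumes recent diffusers (ZImagePipeline, PipelineQuantizationConfig)
# plus bitsandbytes; model id below is a placeholder, not from the repo.
import torch
from diffusers import ZImagePipeline
from diffusers.quantizers import PipelineQuantizationConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",              # or "fp4" for floating point
        "bnb_4bit_compute_dtype": torch.bfloat16,  # compute in bf16 for speed
        "bnb_4bit_use_double_quant": True,         # nested quantization for extra savings
    },
)

pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",  # placeholder for the local model_path
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(prompt="a lighthouse at dusk").images[0]
image.save("out.png")

Compared with the existing 8-bit path, NF4 roughly halves weight memory again, while bnb_4bit_compute_dtype keeps the matmuls in bfloat16 and double quantization shaves a further fraction off by quantizing the quantization constants themselves.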