Hardcode 8-bit quantization for now

oobabooga 2025-11-27 18:23:26 -08:00
parent 822e74ac97
commit 742db85de0


@@ -19,7 +19,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
     """
-    from diffusers import ZImagePipeline
+    from diffusers import PipelineQuantizationConfig, ZImagePipeline
 
     logger.info(f"Loading image model \"{model_name}\"")
     t0 = time.time()
@@ -30,10 +30,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     model_path = resolve_model_path(model_name, image_model=True)
 
     try:
+        # Define quantization config for 8-bit
+        pipeline_quant_config = PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+        )
+
         pipe = ZImagePipeline.from_pretrained(
             str(model_path),
+            quantization_config=pipeline_quant_config,
             torch_dtype=target_dtype,
-            low_cpu_mem_usage=False,
+            low_cpu_mem_usage=True,
         )
 
         if not cpu_offload:
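
For context, a minimal standalone sketch of what this change does, assuming a recent diffusers build that ships ZImagePipeline and PipelineQuantizationConfig, with bitsandbytes installed. The model path and prompt below are placeholders, not values from this repository.

    import torch
    from diffusers import PipelineQuantizationConfig, ZImagePipeline

    # 8-bit bitsandbytes quantization applied to the pipeline's
    # quantizable components (the same config this commit hardcodes).
    quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
    )

    pipe = ZImagePipeline.from_pretrained(
        "path/to/z-image-model",  # placeholder; the webui resolves this from model_name
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,  # stream weights in rather than materializing a full copy in RAM
    )
    pipe.to("cuda")

    # Assuming the standard diffusers text-to-image call signature.
    image = pipe("an astronaut riding a horse").images[0]
    image.save("out.png")

Flipping low_cpu_mem_usage to True pairs naturally with quantized loading: weights are loaded and quantized incrementally instead of being fully materialized in system RAM first, keeping peak memory closer to the 8-bit footprint.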