From 544fcb0b7f0344fac249005f869b02110da69738 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:29:57 -0700 Subject: [PATCH] Simplify modules/image_models.py --- modules/image_models.py | 69 ++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 46 deletions(-) diff --git a/modules/image_models.py b/modules/image_models.py index 290aaf19..eed8783c 100644 --- a/modules/image_models.py +++ b/modules/image_models.py @@ -10,72 +10,49 @@ def get_quantization_config(quant_method): Get the appropriate quantization config based on the selected method. Applies quantization to both the transformer and the text_encoder. """ + if quant_method == 'none' or not quant_method: + return None + import torch - # Import BitsAndBytesConfig from BOTH libraries to be safe from diffusers import BitsAndBytesConfig as DiffusersBnBConfig from diffusers import TorchAoConfig from diffusers.quantizers import PipelineQuantizationConfig from transformers import BitsAndBytesConfig as TransformersBnBConfig - if quant_method == 'none' or not quant_method: - return None + torchao_methods = { + 'torchao-int8wo': 'int8wo', + 'torchao-fp4': 'fp4_e2m1', + 'torchao-float8wo': 'float8wo', + } - # Bitsandbytes 8-bit quantization - elif quant_method == 'bnb-8bit': + if quant_method == 'bnb-8bit': return PipelineQuantizationConfig( quant_mapping={ - "transformer": DiffusersBnBConfig( - load_in_8bit=True - ), - "text_encoder": TransformersBnBConfig( - load_in_8bit=True - ) + "transformer": DiffusersBnBConfig(load_in_8bit=True), + "text_encoder": TransformersBnBConfig(load_in_8bit=True) } ) - # Bitsandbytes 4-bit quantization elif quant_method == 'bnb-4bit': + bnb_4bit_kwargs = dict( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True + ) return PipelineQuantizationConfig( quant_mapping={ - "transformer": DiffusersBnBConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True - ), - "text_encoder": TransformersBnBConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True - ) + "transformer": DiffusersBnBConfig(**bnb_4bit_kwargs), + "text_encoder": TransformersBnBConfig(**bnb_4bit_kwargs) } ) - # torchao int8 weight-only - elif quant_method == 'torchao-int8wo': + elif quant_method in torchao_methods: + ao_type = torchao_methods[quant_method] return PipelineQuantizationConfig( quant_mapping={ - "transformer": TorchAoConfig("int8wo"), - "text_encoder": TorchAoConfig("int8wo") - } - ) - - # torchao fp4 (e2m1) - elif quant_method == 'torchao-fp4': - return PipelineQuantizationConfig( - quant_mapping={ - "transformer": TorchAoConfig("fp4_e2m1"), - "text_encoder": TorchAoConfig("fp4_e2m1") - } - ) - - # torchao float8 weight-only - elif quant_method == 'torchao-float8wo': - return PipelineQuantizationConfig( - quant_mapping={ - "transformer": TorchAoConfig("float8wo"), - "text_encoder": TorchAoConfig("float8wo") + "transformer": TorchAoConfig(ao_type), + "text_encoder": TorchAoConfig(ao_type) } )