Simplify modules/image_models.py

This commit is contained in:
oobabooga 2026-04-04 23:29:57 -07:00
parent c63a79ee48
commit 544fcb0b7f

View file

def get_quantization_config(quant_method):
    """
    Get the appropriate quantization config based on the selected method.
    Applies quantization to both the transformer and the text_encoder.
    """
    # Falsy or explicit 'none' → no quantization requested.
    if not quant_method or quant_method == 'none':
        return None

    # Lazy imports: these are heavy and only needed when quantizing.
    import torch
    # Import BitsAndBytesConfig from BOTH libraries to be safe
    from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
    from diffusers import TorchAoConfig
    from diffusers.quantizers import PipelineQuantizationConfig
    from transformers import BitsAndBytesConfig as TransformersBnBConfig

    # UI method name → torchao quantization type string.
    torchao_methods = {
        'torchao-int8wo': 'int8wo',
        'torchao-fp4': 'fp4_e2m1',
        'torchao-float8wo': 'float8wo',
    }

    if quant_method == 'bnb-8bit':
        # Bitsandbytes 8-bit: same flag for both components.
        mapping = {
            "transformer": DiffusersBnBConfig(load_in_8bit=True),
            "text_encoder": TransformersBnBConfig(load_in_8bit=True)
        }
        return PipelineQuantizationConfig(quant_mapping=mapping)
    elif quant_method == 'bnb-4bit':
        # Bitsandbytes 4-bit NF4 with double quantization; one kwargs
        # dict shared by the diffusers and transformers config classes.
        nf4_settings = dict(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True
        )
        mapping = {
            "transformer": DiffusersBnBConfig(**nf4_settings),
            "text_encoder": TransformersBnBConfig(**nf4_settings)
        }
        return PipelineQuantizationConfig(quant_mapping=mapping)
    elif quant_method in torchao_methods:
        # torchao weight-only methods all share one code path.
        torchao_dtype = torchao_methods[quant_method]
        mapping = {
            "transformer": TorchAoConfig(torchao_dtype),
            "text_encoder": TorchAoConfig(torchao_dtype)
        }
        return PipelineQuantizationConfig(quant_mapping=mapping)