Image generation: add torchao quantization (supports torch.compile)

This commit is contained in:
oobabooga 2025-12-02 14:22:51 -08:00
parent 97281ff831
commit 9448bf1caa
12 changed files with 40 additions and 6 deletions

View file

@ -10,13 +10,14 @@ def get_quantization_config(quant_method):
Get the appropriate quantization config based on the selected method.
Args:
quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit'
quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit',
'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'
Returns:
PipelineQuantizationConfig or None
"""
import torch
from diffusers import BitsAndBytesConfig, QuantoConfig
from diffusers import BitsAndBytesConfig, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
if quant_method == 'none' or not quant_method:
@ -45,6 +46,30 @@ def get_quantization_config(quant_method):
}
)
# torchao int8 weight-only
elif quant_method == 'torchao-int8wo':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig("int8wo")
}
)
# torchao fp4 (e2m1)
elif quant_method == 'torchao-fp4':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig("fp4_e2m1")
}
)
# torchao float8 weight-only
elif quant_method == 'torchao-float8wo':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig("float8wo")
}
)
else:
logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
return None
@ -76,7 +101,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit'
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', or 'torchao-float8wo'
"""
import torch
from diffusers import DiffusionPipeline

View file

@ -60,7 +60,7 @@ group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdp
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
group.add_argument('--image-quant', type=str, default=None,
choices=['none', 'bnb-8bit', 'bnb-4bit'],
choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'],
help='Quantization method for image model.')
# Model loader

View file

@ -473,9 +473,9 @@ def create_ui():
with gr.Column():
shared.gradio['image_quant'] = gr.Dropdown(
label='Quantization',
choices=['none', 'bnb-8bit', 'bnb-4bit'],
choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'],
value=shared.settings['image_quant'],
info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
info='BnB: bitsandbytes quantization. torchao: int8wo, fp4, float8wo.'
)
shared.gradio['image_dtype'] = gr.Dropdown(

View file

@ -25,6 +25,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -25,6 +25,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm

View file

@ -23,6 +23,7 @@ safetensors==0.6.*
scipy
sentencepiece
tensorboard
torchao==0.14.*
transformers==4.57.*
triton-windows==3.5.1.post21; platform_system == "Windows"
tqdm