Image generation: add torchao quantization (supports torch.compile)

This commit is contained in:
oobabooga 2025-12-02 14:22:51 -08:00
parent 97281ff831
commit 9448bf1caa
12 changed files with 40 additions and 6 deletions

View file

@@ -10,13 +10,14 @@ def get_quantization_config(quant_method):
     Get the appropriate quantization config based on the selected method.

     Args:
-        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit'
+        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit',
+            'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'

     Returns:
         PipelineQuantizationConfig or None
     """
     import torch
-    from diffusers import BitsAndBytesConfig, QuantoConfig
+    from diffusers import BitsAndBytesConfig, TorchAoConfig
     from diffusers.quantizers import PipelineQuantizationConfig

     if quant_method == 'none' or not quant_method:
@@ -45,6 +46,30 @@ def get_quantization_config(quant_method):
             }
         )
+
+    # torchao int8 weight-only
+    elif quant_method == 'torchao-int8wo':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig("int8wo")
+            }
+        )
+
+    # torchao fp4 (e2m1)
+    elif quant_method == 'torchao-fp4':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig("fp4_e2m1")
+            }
+        )
+
+    # torchao float8 weight-only
+    elif quant_method == 'torchao-float8wo':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig("float8wo")
+            }
+        )

     else:
         logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
         return None
@@ -76,7 +101,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
-        quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit'
+        quant_method: 'none', 'bnb-8bit', 'bnb-4bit', or torchao options (int8wo, fp4, float8wo)
     """
     import torch
     from diffusers import DiffusionPipeline

View file

@@ -60,7 +60,7 @@ group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdp
 group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
 group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
 group.add_argument('--image-quant', type=str, default=None,
-                   choices=['none', 'bnb-8bit', 'bnb-4bit'],
+                   choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'],
                    help='Quantization method for image model.')

 # Model loader

View file

@@ -473,9 +473,9 @@ def create_ui():
             with gr.Column():
                 shared.gradio['image_quant'] = gr.Dropdown(
                     label='Quantization',
-                    choices=['none', 'bnb-8bit', 'bnb-4bit'],
+                    choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'],
                     value=shared.settings['image_quant'],
-                    info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
+                    info='BnB: bitsandbytes quantization. torchao: int8wo, fp4, float8wo.'
                 )
                 shared.gradio['image_dtype'] = gr.Dropdown(

View file

@@ -25,6 +25,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -25,6 +25,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm

View file

@@ -23,6 +23,7 @@ safetensors==0.6.*
 scipy
 sentencepiece
 tensorboard
+torchao==0.14.*
 transformers==4.57.*
 triton-windows==3.5.1.post21; platform_system == "Windows"
 tqdm