Add quantization options (bnb and quanto)

oobabooga 2025-12-01 17:05:42 -08:00
parent a7808f7f42
commit 7dfb6e9c57
4 changed files with 109 additions and 29 deletions

View file

@@ -8,7 +8,75 @@ from modules.torch_utils import get_device
 from modules.utils import resolve_model_path


-def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
+def get_quantization_config(quant_method):
+    """
+    Get the appropriate quantization config based on the selected method.
+
+    Args:
+        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+
+    Returns:
+        PipelineQuantizationConfig or None
+    """
+    from diffusers.quantizers import PipelineQuantizationConfig
+    from diffusers import BitsAndBytesConfig, QuantoConfig
+
+    if quant_method == 'none' or not quant_method:
+        return None
+
+    # Bitsandbytes 8-bit quantization
+    elif quant_method == 'bnb-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_8bit=True
+                )
+            }
+        )
+
+    # Bitsandbytes 4-bit quantization
+    elif quant_method == 'bnb-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True
+                )
+            }
+        )
+
+    # Quanto 8-bit quantization
+    elif quant_method == 'quanto-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8")
+            }
+        )
+
+    # Quanto 4-bit quantization
+    elif quant_method == 'quanto-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int4")
+            }
+        )
+
+    # Quanto 2-bit quantization
+    elif quant_method == 'quanto-2bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int2")
+            }
+        )
+
+    else:
+        logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
+        return None
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
@@ -18,10 +86,11 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
+        quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
-    from diffusers import PipelineQuantizationConfig, ZImagePipeline
+    from diffusers import ZImagePipeline

-    logger.info(f"Loading image model \"{model_name}\"")
+    logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()

     dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
@@ -30,28 +99,21 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     model_path = resolve_model_path(model_name, image_model=True)

     try:
-        # Define quantization config for 8-bit
-        pipeline_quant_config = PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
-            quant_kwargs={"load_in_8bit": True},
-        )
+        # Get quantization config based on selected method
+        pipeline_quant_config = get_quantization_config(quant_method)

-        # Define quantization config for 4-bit
-        # pipeline_quant_config = PipelineQuantizationConfig(
-        #     quant_backend="bitsandbytes_4bit",
-        #     quant_kwargs={
-        #         "load_in_4bit": True,
-        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
-        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
-        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
-        #     },
-        # )
         # Load the pipeline
+        load_kwargs = {
+            "torch_dtype": target_dtype,
+            "low_cpu_mem_usage": True,
+        }
+        if pipeline_quant_config is not None:
+            load_kwargs["quantization_config"] = pipeline_quant_config
+
         pipe = ZImagePipeline.from_pretrained(
             str(model_path),
-            quantization_config=pipeline_quant_config,
-            torch_dtype=target_dtype,
-            low_cpu_mem_usage=True,
+            **load_kwargs
         )

         if not cpu_offload:
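
The plumbing above reduces to: build an optional PipelineQuantizationConfig, then pass quantization_config to from_pretrained only when one was produced. A minimal usage sketch under stated assumptions — the module path for the helper is assumed (file names are not shown in this diff), the model path is a placeholder, and 'bnb-4bit' requires bitsandbytes to be installed:

    # Illustration only, not part of the commit.
    import torch
    from diffusers import ZImagePipeline

    from modules.image_models import get_quantization_config  # module path assumed

    quant_config = get_quantization_config('bnb-4bit')  # PipelineQuantizationConfig or None

    load_kwargs = {"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True}
    if quant_config is not None:
        load_kwargs["quantization_config"] = quant_config

    pipe = ZImagePipeline.from_pretrained("path/to/z-image-model", **load_kwargs)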

View file

@@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16',
 group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
 group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
 group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+group.add_argument('--image-quant', type=str, default=None,
+                   choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                   help='Quantization method for image model.')

 # Model loader
 group = parser.add_argument_group('Model loader')
@@ -317,8 +320,9 @@ settings = {
     'image_model_menu': 'None',
     'image_dtype': 'bfloat16',
     'image_attn_backend': 'sdpa',
-    'image_compile': False,
     'image_cpu_offload': False,
+    'image_compile': False,
+    'image_quant': 'none',
 }

 default_settings = copy.deepcopy(settings)
@@ -344,8 +348,8 @@ def do_cmd_flags_warnings():
 def apply_image_model_cli_overrides():
-    """Apply CLI flags for image model settings, overriding saved settings."""
-    if args.image_model:
+    """Apply command-line overrides for image model settings."""
+    if args.image_model is not None:
         settings['image_model_menu'] = args.image_model
     if args.image_dtype is not None:
         settings['image_dtype'] = args.image_dtype
@@ -355,6 +359,9 @@ def apply_image_model_cli_overrides():
         settings['image_cpu_offload'] = True
     if args.image_compile:
         settings['image_compile'] = True
+    if args.image_quant is not None:
+        settings['image_quant'] = args.image_quant


 def fix_loader_name(name):
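
The new flag follows the same precedence as the other image-model flags: the value saved in settings is kept unless --image-quant is passed explicitly, in which case the command-line value wins. A self-contained sketch of that precedence (the Args class below is a stand-in for the parsed argparse namespace, not the repo's actual object):

    # Illustration only, not part of the commit.
    settings = {'image_quant': 'none'}      # value restored from saved settings

    class Args:
        image_quant = 'quanto-4bit'         # as if --image-quant quanto-4bit was passed

    args = Args()
    if args.image_quant is not None:        # mirrors apply_image_model_cli_overrides()
        settings['image_quant'] = args.image_quant

    assert settings['image_quant'] == 'quanto-4bit'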

View file

@@ -296,6 +296,7 @@ def list_interface_input_elements():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]

     return elements
@@ -542,6 +543,7 @@ def setup_auto_save():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]

     for element_name in change_elements:
View file

@@ -432,6 +432,13 @@ def create_ui():
     gr.Markdown("## Settings")
     with gr.Row():
         with gr.Column():
+            shared.gradio['image_quant'] = gr.Dropdown(
+                label='Quantization',
+                choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                value=shared.settings['image_quant'],
+                info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
+            )
             shared.gradio['image_dtype'] = gr.Dropdown(
                 choices=['bfloat16', 'float16'],
                 value=shared.settings['image_dtype'],
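
One practical note on this dropdown: the 'bnb-*' choices rely on the bitsandbytes package and the 'quanto-*' choices on optimum-quanto, and nothing in this diff checks for either up front. A hedged sketch of how availability could be probed before loading (helper name and placement are hypothetical, not part of the commit):

    # Hypothetical helper: verify the backend required by the selected method is importable.
    import importlib.util

    def quant_backend_available(quant_method):
        if quant_method.startswith('bnb-'):
            return importlib.util.find_spec('bitsandbytes') is not None
        if quant_method.startswith('quanto-'):
            try:
                return importlib.util.find_spec('optimum.quanto') is not None
            except ModuleNotFoundError:  # optimum itself is not installed
                return False
        return True  # 'none' needs no extra backend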
@@ -521,7 +528,7 @@ def create_event_handlers():
     shared.gradio['image_load_model'].click(
         load_image_model_wrapper,
-        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
         gradio('image_model_status'),
         show_progress=True
     )
@@ -610,7 +617,8 @@ def generate(state):
             dtype=state['image_dtype'],
             attn_backend=state['image_attn_backend'],
             cpu_offload=state['image_cpu_offload'],
-            compile_model=state['image_compile']
+            compile_model=state['image_compile'],
+            quant_method=state['image_quant']
         )

         if result is None:
             logger.error(f"Failed to load model `{model_name}`.")
@@ -647,7 +655,7 @@ def generate(state):
     return all_images


-def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
+def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
     if not model_name or model_name == 'None':
         yield "No model selected"
         return
@@ -661,12 +669,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
             dtype=dtype,
             attn_backend=attn_backend,
             cpu_offload=cpu_offload,
-            compile_model=compile_model
+            compile_model=compile_model,
+            quant_method=quant_method
         )

         if result is not None:
             shared.image_model_name = model_name
-            yield f"✓ Loaded **{model_name}**"
+            yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
         else:
             yield f"✗ Failed to load `{model_name}`"
     except Exception: