From 7dfb6e9c57c6ac45c2c77409e032c59739b7724b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:05:42 -0800 Subject: [PATCH] Add quantization options (bnb and quanto) --- modules/image_models.py | 104 ++++++++++++++++++++++++++------- modules/shared.py | 13 ++++- modules/ui.py | 2 + modules/ui_image_generation.py | 19 ++++-- 4 files changed, 109 insertions(+), 29 deletions(-) diff --git a/modules/image_models.py b/modules/image_models.py index 9e2075fd..de3743bf 100644 --- a/modules/image_models.py +++ b/modules/image_models.py @@ -8,7 +8,75 @@ from modules.torch_utils import get_device from modules.utils import resolve_model_path -def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False): +def get_quantization_config(quant_method): + """ + Get the appropriate quantization config based on the selected method. + + Args: + quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit' + + Returns: + PipelineQuantizationConfig or None + """ + from diffusers.quantizers import PipelineQuantizationConfig + from diffusers import BitsAndBytesConfig, QuantoConfig + + if quant_method == 'none' or not quant_method: + return None + + # Bitsandbytes 8-bit quantization + elif quant_method == 'bnb-8bit': + return PipelineQuantizationConfig( + quant_mapping={ + "transformer": BitsAndBytesConfig( + load_in_8bit=True + ) + } + ) + + # Bitsandbytes 4-bit quantization + elif quant_method == 'bnb-4bit': + return PipelineQuantizationConfig( + quant_mapping={ + "transformer": BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True + ) + } + ) + + # Quanto 8-bit quantization + elif quant_method == 'quanto-8bit': + return PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8") + } + ) + + # Quanto 4-bit quantization + elif quant_method == 'quanto-4bit': + return PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int4") + } + ) + + # Quanto 2-bit quantization + elif quant_method == 'quanto-2bit': + return PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int2") + } + ) + + else: + logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.") + return None + + +def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'): """ Load a diffusers image generation model. 
@@ -18,10 +86,11 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3' cpu_offload: Enable CPU offloading for low VRAM compile_model: Compile the model for faster inference (slow first run) + quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit' """ - from diffusers import PipelineQuantizationConfig, ZImagePipeline + from diffusers import ZImagePipeline - logger.info(f"Loading image model \"{model_name}\"") + logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}") t0 = time.time() dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16} @@ -30,28 +99,21 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl model_path = resolve_model_path(model_name, image_model=True) try: - # Define quantization config for 8-bit - pipeline_quant_config = PipelineQuantizationConfig( - quant_backend="bitsandbytes_8bit", - quant_kwargs={"load_in_8bit": True}, - ) + # Get quantization config based on selected method + pipeline_quant_config = get_quantization_config(quant_method) - # Define quantization config for 4-bit - # pipeline_quant_config = PipelineQuantizationConfig( - # quant_backend="bitsandbytes_4bit", - # quant_kwargs={ - # "load_in_4bit": True, - # "bnb_4bit_quant_type": "nf4", # Or "fp4" for floating point - # "bnb_4bit_compute_dtype": torch.bfloat16, # For faster computation - # "bnb_4bit_use_double_quant": True, # Nested quantization for extra savings - # }, - # ) + # Load the pipeline + load_kwargs = { + "torch_dtype": target_dtype, + "low_cpu_mem_usage": True, + } + + if pipeline_quant_config is not None: + load_kwargs["quantization_config"] = pipeline_quant_config pipe = ZImagePipeline.from_pretrained( str(model_path), - quantization_config=pipeline_quant_config, - torch_dtype=target_dtype, - low_cpu_mem_usage=True, + **load_kwargs ) if not cpu_offload: diff --git a/modules/shared.py b/modules/shared.py index 9a062e91..d33aa717 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.') group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.') group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.') +group.add_argument('--image-quant', type=str, default=None, + choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'], + help='Quantization method for image model.') # Model loader group = parser.add_argument_group('Model loader') @@ -317,8 +320,9 @@ settings = { 'image_model_menu': 'None', 'image_dtype': 'bfloat16', 'image_attn_backend': 'sdpa', - 'image_compile': False, 'image_cpu_offload': False, + 'image_compile': False, + 'image_quant': 'none', } default_settings = copy.deepcopy(settings) @@ -344,8 +348,8 @@ def do_cmd_flags_warnings(): def apply_image_model_cli_overrides(): - """Apply CLI flags for image model settings, overriding saved settings.""" - if args.image_model: + """Apply command-line overrides for image model settings.""" + if args.image_model is not None: settings['image_model_menu'] = args.image_model if args.image_dtype is not None: settings['image_dtype'] = 
args.image_dtype @@ -355,6 +359,9 @@ def apply_image_model_cli_overrides(): settings['image_cpu_offload'] = True if args.image_compile: settings['image_compile'] = True + if args.image_quant is not None: + settings['image_quant'] = args.image_quant + def fix_loader_name(name): diff --git a/modules/ui.py b/modules/ui.py index 3aba20b4..3bcba56b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -296,6 +296,7 @@ def list_interface_input_elements(): 'image_attn_backend', 'image_compile', 'image_cpu_offload', + 'image_quant', ] return elements @@ -542,6 +543,7 @@ def setup_auto_save(): 'image_attn_backend', 'image_compile', 'image_cpu_offload', + 'image_quant', ] for element_name in change_elements: diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index ff1b9f67..a5cf3695 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -432,6 +432,13 @@ def create_ui(): gr.Markdown("## Settings") with gr.Row(): with gr.Column(): + shared.gradio['image_quant'] = gr.Dropdown( + label='Quantization', + choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'], + value=shared.settings['image_quant'], + info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).' + ) + shared.gradio['image_dtype'] = gr.Dropdown( choices=['bfloat16', 'float16'], value=shared.settings['image_dtype'], @@ -521,7 +528,7 @@ def create_event_handlers(): shared.gradio['image_load_model'].click( load_image_model_wrapper, - gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'), + gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'), gradio('image_model_status'), show_progress=True ) @@ -610,7 +617,8 @@ def generate(state): dtype=state['image_dtype'], attn_backend=state['image_attn_backend'], cpu_offload=state['image_cpu_offload'], - compile_model=state['image_compile'] + compile_model=state['image_compile'], + quant_method=state['image_quant'] ) if result is None: logger.error(f"Failed to load model `{model_name}`.") @@ -647,7 +655,7 @@ def generate(state): return all_images -def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model): +def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method): if not model_name or model_name == 'None': yield "No model selected" return @@ -661,12 +669,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi dtype=dtype, attn_backend=attn_backend, cpu_offload=cpu_offload, - compile_model=compile_model + compile_model=compile_model, + quant_method=quant_method ) if result is not None: shared.image_model_name = model_name - yield f"✓ Loaded **{model_name}**" + yield f"✓ Loaded **{model_name}** (quantization: {quant_method})" else: yield f"✗ Failed to load `{model_name}`" except Exception:
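
Usage note (illustrative, not part of the patch): the new quant_method values map directly onto diffusers' pipeline-level quantization via PipelineQuantizationConfig with a per-component quant_mapping, exactly as get_quantization_config builds it above. A minimal standalone sketch of the Quanto int4 path, assuming a diffusers build with Quanto/bitsandbytes support and a placeholder local Z-Image checkpoint path:

    import torch
    from diffusers import ZImagePipeline, QuantoConfig
    from diffusers.quantizers import PipelineQuantizationConfig

    # Quantize only the transformer to int4; the rest of the pipeline stays in bfloat16.
    quant_config = PipelineQuantizationConfig(
        quant_mapping={"transformer": QuantoConfig(weights_dtype="int4")}
    )

    pipe = ZImagePipeline.from_pretrained(
        "path/to/z-image-checkpoint",      # placeholder path, not a real model id
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

From the UI side, the same code path is exercised by picking an entry in the new Quantization dropdown or by launching with --image-quant set to one of the six listed choices (e.g. --image-quant bnb-4bit), which apply_image_model_cli_overrides copies into settings['image_quant'].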