mirror of https://github.com/oobabooga/text-generation-webui.git
Add quantization options (bnb and quanto)
parent a7808f7f42
commit 7dfb6e9c57

@@ -8,7 +8,75 @@ from modules.torch_utils import get_device
 from modules.utils import resolve_model_path
 
 
-def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
+def get_quantization_config(quant_method):
+    """
+    Get the appropriate quantization config based on the selected method.
+
+    Args:
+        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+
+    Returns:
+        PipelineQuantizationConfig or None
+    """
+    from diffusers.quantizers import PipelineQuantizationConfig
+    from diffusers import BitsAndBytesConfig, QuantoConfig
+
+    if quant_method == 'none' or not quant_method:
+        return None
+
+    # Bitsandbytes 8-bit quantization
+    elif quant_method == 'bnb-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_8bit=True
+                )
+            }
+        )
+
+    # Bitsandbytes 4-bit quantization
+    elif quant_method == 'bnb-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True
+                )
+            }
+        )
+
+    # Quanto 8-bit quantization
+    elif quant_method == 'quanto-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8")
+            }
+        )
+
+    # Quanto 4-bit quantization
+    elif quant_method == 'quanto-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int4")
+            }
+        )
+
+    # Quanto 2-bit quantization
+    elif quant_method == 'quanto-2bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int2")
+            }
+        )
+
+    else:
+        logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
+        return None
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
 

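The new helper returns a diffusers PipelineQuantizationConfig built from an explicit quant_mapping keyed by pipeline component name, so only the "transformer" component is quantized. The same intent can also be written with the quant_backend/quant_kwargs shorthand that the old hard-coded config used. A rough sketch of both forms, assuming a diffusers install with bitsandbytes available; the components_to_quantize argument is an assumption for illustration and is not taken from this commit:

    # Sketch: two ways to request 8-bit bitsandbytes quantization for the
    # "transformer" component (illustration only, not part of the commit).
    from diffusers import BitsAndBytesConfig
    from diffusers.quantizers import PipelineQuantizationConfig

    explicit = PipelineQuantizationConfig(
        quant_mapping={"transformer": BitsAndBytesConfig(load_in_8bit=True)}
    )

    shorthand = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer"],  # assumption: restrict to the transformer
    )
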
@@ -18,10 +86,11 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
+        quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
-    from diffusers import PipelineQuantizationConfig, ZImagePipeline
+    from diffusers import ZImagePipeline
 
-    logger.info(f"Loading image model \"{model_name}\"")
+    logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()
 
     dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}

@@ -30,28 +99,21 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     model_path = resolve_model_path(model_name, image_model=True)
 
     try:
-        # Define quantization config for 8-bit
-        pipeline_quant_config = PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
-            quant_kwargs={"load_in_8bit": True},
-        )
+        # Get quantization config based on selected method
+        pipeline_quant_config = get_quantization_config(quant_method)
 
-        # Define quantization config for 4-bit
-        # pipeline_quant_config = PipelineQuantizationConfig(
-        #     quant_backend="bitsandbytes_4bit",
-        #     quant_kwargs={
-        #         "load_in_4bit": True,
-        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
-        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
-        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
-        #     },
-        # )
+        # Load the pipeline
+        load_kwargs = {
+            "torch_dtype": target_dtype,
+            "low_cpu_mem_usage": True,
+        }
+
+        if pipeline_quant_config is not None:
+            load_kwargs["quantization_config"] = pipeline_quant_config
 
         pipe = ZImagePipeline.from_pretrained(
             str(model_path),
-            quantization_config=pipeline_quant_config,
-            torch_dtype=target_dtype,
-            low_cpu_mem_usage=True,
+            **load_kwargs
         )
 
         if not cpu_offload:

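With the changes above, the loader only attaches quantization_config when a method other than 'none' is selected, so unquantized loads behave exactly as before. A minimal sketch of calling the updated function with the new keyword (the model directory name is a placeholder, not something this commit ships):

    # Sketch only: exercising the updated load_image_model() signature.
    pipe = load_image_model(
        'z-image-turbo',             # placeholder folder under the image-model directory
        dtype='bfloat16',
        attn_backend='sdpa',
        cpu_offload=False,
        compile_model=False,
        quant_method='quanto-4bit',  # 4-bit weights via the quanto backend
    )
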
@@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16',
 group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
 group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
 group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+group.add_argument('--image-quant', type=str, default=None,
+                   choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                   help='Quantization method for image model.')
 
 # Model loader
 group = parser.add_argument_group('Model loader')

@@ -317,8 +320,9 @@ settings = {
     'image_model_menu': 'None',
     'image_dtype': 'bfloat16',
     'image_attn_backend': 'sdpa',
-    'image_compile': False,
     'image_cpu_offload': False,
+    'image_compile': False,
+    'image_quant': 'none',
 }
 
 default_settings = copy.deepcopy(settings)

@@ -344,8 +348,8 @@ def do_cmd_flags_warnings():
 
 
 def apply_image_model_cli_overrides():
-    """Apply CLI flags for image model settings, overriding saved settings."""
-    if args.image_model:
+    """Apply command-line overrides for image model settings."""
+    if args.image_model is not None:
         settings['image_model_menu'] = args.image_model
     if args.image_dtype is not None:
         settings['image_dtype'] = args.image_dtype

@@ -355,6 +359,9 @@ def apply_image_model_cli_overrides():
         settings['image_cpu_offload'] = True
     if args.image_compile:
         settings['image_compile'] = True
+    if args.image_quant is not None:
+        settings['image_quant'] = args.image_quant
+
 
 
 def fix_loader_name(name):

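The new flag follows the same pattern as the other image-model options: its argparse default is None, and apply_image_model_cli_overrides() only touches the saved setting when the flag was actually passed (for example --image-quant bnb-4bit). A self-contained sketch of that precedence, using the same names as the diff but not the project's actual startup code:

    # Sketch: the CLI value wins only when the flag is given, because the
    # argparse default is None rather than 'none'.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--image-quant', type=str, default=None,
                        choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'])

    settings = {'image_quant': 'none'}        # saved default
    args = parser.parse_args([])              # launched without --image-quant
    if args.image_quant is not None:          # same check as apply_image_model_cli_overrides()
        settings['image_quant'] = args.image_quant
    print(settings['image_quant'])            # still 'none'

    args = parser.parse_args(['--image-quant', 'bnb-4bit'])
    if args.image_quant is not None:
        settings['image_quant'] = args.image_quant
    print(settings['image_quant'])            # now 'bnb-4bit'
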
@@ -296,6 +296,7 @@ def list_interface_input_elements():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]
 
     return elements

@@ -542,6 +543,7 @@ def setup_auto_save():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]
 
     for element_name in change_elements:

@@ -432,6 +432,13 @@ def create_ui():
     gr.Markdown("## Settings")
     with gr.Row():
         with gr.Column():
+            shared.gradio['image_quant'] = gr.Dropdown(
+                label='Quantization',
+                choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                value=shared.settings['image_quant'],
+                info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
+            )
+
             shared.gradio['image_dtype'] = gr.Dropdown(
                 choices=['bfloat16', 'float16'],
                 value=shared.settings['image_dtype'],

@@ -521,7 +528,7 @@ def create_event_handlers():
 
     shared.gradio['image_load_model'].click(
         load_image_model_wrapper,
-        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
         gradio('image_model_status'),
        show_progress=True
     )

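The dropdown reaches the loader through the click handler's positional input list, which is why 'image_quant' is appended in the same position to both the gradio(...) tuple and the wrapper's signature. A stripped-down sketch of that wiring pattern (component values are stand-ins; this is not the project's actual UI code):

    # Sketch: click() passes its input components positionally, so the new
    # quantization dropdown must line up with the wrapper's quant_method parameter.
    import gradio as gr

    def load_image_model_wrapper(model_name, dtype, attn_backend,
                                 cpu_offload, compile_model, quant_method):
        return f"Would load {model_name} with quantization: {quant_method}"

    with gr.Blocks() as demo:
        model = gr.Textbox(value='z-image-turbo')        # placeholder model name
        dtype = gr.Dropdown(['bfloat16', 'float16'], value='bfloat16')
        attn = gr.Dropdown(['sdpa'], value='sdpa')
        offload = gr.Checkbox(value=False)
        compile_box = gr.Checkbox(value=False)
        quant = gr.Dropdown(['none', 'bnb-8bit', 'bnb-4bit',
                             'quanto-8bit', 'quanto-4bit', 'quanto-2bit'], value='none')
        status = gr.Markdown()
        gr.Button('Load').click(
            load_image_model_wrapper,
            [model, dtype, attn, offload, compile_box, quant],
            status,
        )
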
@@ -610,7 +617,8 @@ def generate(state):
             dtype=state['image_dtype'],
             attn_backend=state['image_attn_backend'],
             cpu_offload=state['image_cpu_offload'],
-            compile_model=state['image_compile']
+            compile_model=state['image_compile'],
+            quant_method=state['image_quant']
         )
         if result is None:
             logger.error(f"Failed to load model `{model_name}`.")

@@ -647,7 +655,7 @@ def generate(state):
     return all_images
 
 
-def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
+def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
    if not model_name or model_name == 'None':
        yield "No model selected"
        return

@@ -661,12 +669,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
             dtype=dtype,
             attn_backend=attn_backend,
             cpu_offload=cpu_offload,
-            compile_model=compile_model
+            compile_model=compile_model,
+            quant_method=quant_method
         )
 
         if result is not None:
             shared.image_model_name = model_name
-            yield f"✓ Loaded **{model_name}**"
+            yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
         else:
             yield f"✗ Failed to load `{model_name}`"
     except Exception: