mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-07 07:33:47 +00:00
Simplify modules/image_models.py
This commit is contained in:
parent
c63a79ee48
commit
544fcb0b7f
1 changed file with 23 additions and 46 deletions
|
|
def get_quantization_config(quant_method):
    """
    Get the appropriate quantization config based on the selected method.

    Applies quantization to both the transformer and the text_encoder.

    Parameters:
        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit',
            'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo',
            or a falsy value.

    Returns:
        A diffusers PipelineQuantizationConfig for the selected method,
        or None when quantization is disabled or the method is unrecognized.
    """
    # Early return before the heavy imports below, so 'none' costs nothing.
    if quant_method == 'none' or not quant_method:
        return None

    import torch
    # BitsAndBytesConfig exists in BOTH libraries; the diffusers one configures
    # the transformer, the transformers one configures the text_encoder.
    from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
    from diffusers import TorchAoConfig
    from diffusers.quantizers import PipelineQuantizationConfig
    from transformers import BitsAndBytesConfig as TransformersBnBConfig

    # Map UI method names to torchao quantization type strings.
    torchao_methods = {
        'torchao-int8wo': 'int8wo',
        'torchao-fp4': 'fp4_e2m1',
        'torchao-float8wo': 'float8wo',
    }

    # Bitsandbytes 8-bit quantization
    if quant_method == 'bnb-8bit':
        return PipelineQuantizationConfig(
            quant_mapping={
                "transformer": DiffusersBnBConfig(load_in_8bit=True),
                "text_encoder": TransformersBnBConfig(load_in_8bit=True)
            }
        )

    # Bitsandbytes 4-bit quantization (NF4 with double quantization)
    elif quant_method == 'bnb-4bit':
        # Both configs take identical kwargs; build them once.
        bnb_4bit_kwargs = dict(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True
        )
        return PipelineQuantizationConfig(
            quant_mapping={
                "transformer": DiffusersBnBConfig(**bnb_4bit_kwargs),
                "text_encoder": TransformersBnBConfig(**bnb_4bit_kwargs)
            }
        )

    # torchao weight-only / fp4 quantization, driven by the lookup table
    elif quant_method in torchao_methods:
        ao_type = torchao_methods[quant_method]
        return PipelineQuantizationConfig(
            quant_mapping={
                "transformer": TorchAoConfig(ao_type),
                "text_encoder": TorchAoConfig(ao_type)
            }
        )

    # Unrecognized method: fall through with no quantization.
    return None
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue