Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-01-03 07:10:11 +01:00)
Initial Qwen-Image support
This commit is contained in:
parent 225b8c326b
commit f46f49e26c
@@ -75,6 +75,22 @@ def get_quantization_config(quant_method):
     return None
 
 
+def get_pipeline_type(pipe):
+    """
+    Detect the pipeline type based on the loaded pipeline class.
+
+    Returns:
+        str: 'zimage', 'qwenimage', or 'unknown'
+    """
+    class_name = pipe.__class__.__name__
+    if 'ZImage' in class_name:
+        return 'zimage'
+    elif 'QwenImage' in class_name:
+        return 'qwenimage'
+    else:
+        return 'unknown'
+
+
 def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
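A minimal usage sketch of the new helper (not part of the commit); the detection relies only on the diffusers class name containing "ZImage" or "QwenImage":

    # Illustrative only: classify an already-loaded pipeline
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # repo id shown for illustration
    print(get_pipeline_type(pipe))  # 'qwenimage' if the loaded class name contains 'QwenImage'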
@@ -88,7 +104,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
     import torch
-    from diffusers import ZImagePipeline
+    from diffusers import DiffusionPipeline
 
     logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()
@@ -111,30 +127,37 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     if pipeline_quant_config is not None:
         load_kwargs["quantization_config"] = pipeline_quant_config
 
-    pipe = ZImagePipeline.from_pretrained(
+    # Use DiffusionPipeline for automatic pipeline detection
+    # This handles both ZImagePipeline and QwenImagePipeline
+    pipe = DiffusionPipeline.from_pretrained(
         str(model_path),
         **load_kwargs
     )
 
+    pipeline_type = get_pipeline_type(pipe)
+
     if not cpu_offload:
         pipe.to(get_device())
 
-    # Set attention backend
-    if attn_backend == 'flash_attention_2':
-        pipe.transformer.set_attention_backend("flash")
-    elif attn_backend == 'flash_attention_3':
-        pipe.transformer.set_attention_backend("_flash_3")
-    # sdpa is the default, no action needed
+    # Set attention backend (if supported by the pipeline)
+    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+        if attn_backend == 'flash_attention_2':
+            pipe.transformer.set_attention_backend("flash")
+        elif attn_backend == 'flash_attention_3':
+            pipe.transformer.set_attention_backend("_flash_3")
+        # sdpa is the default, no action needed
 
     if compile_model:
-        logger.info("Compiling model (first run will be slow)...")
-        pipe.transformer.compile()
+        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+            logger.info("Compiling model (first run will be slow)...")
+            pipe.transformer.compile()
 
     if cpu_offload:
         pipe.enable_model_cpu_offload()
 
     shared.image_model = pipe
     shared.image_model_name = model_name
+    shared.image_pipeline_type = pipeline_type
 
     logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
     return pipe
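After this hunk, both model families go through the same loader: DiffusionPipeline.from_pretrained resolves the concrete pipeline class from the checkpoint's model_index.json, so the explicit ZImagePipeline import is no longer needed. A hedged usage sketch (the model name is illustrative):

    # Illustrative only: load a model, then inspect what was detected
    pipe = load_image_model("Qwen-Image", dtype='bfloat16', attn_backend='sdpa')
    print(shared.image_pipeline_type)  # 'qwenimage', 'zimage', or 'unknown'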
@@ -152,6 +175,7 @@ def unload_image_model():
     del shared.image_model
     shared.image_model = None
     shared.image_model_name = 'None'
+    shared.image_pipeline_type = None
 
     from modules.torch_utils import clear_torch_cache
     clear_torch_cache()
@@ -604,7 +604,12 @@ def create_event_handlers():
 
 
 def generate(state):
+    """
+    Generate images using the loaded model.
+    Automatically adjusts parameters based on pipeline type.
+    """
     import torch
+    import numpy as np
 
     model_name = state['image_model_menu']
 
@@ -634,19 +639,34 @@ def generate(state):
     generator = torch.Generator("cuda").manual_seed(int(seed))
     all_images = []
 
+    # Get pipeline type for parameter adjustment
+    pipeline_type = getattr(shared, 'image_pipeline_type', None)
+    if pipeline_type is None:
+        pipeline_type = get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs based on pipeline type
+    gen_kwargs = {
+        "prompt": state['image_prompt'],
+        "negative_prompt": state['image_neg_prompt'],
+        "height": int(state['image_height']),
+        "width": int(state['image_width']),
+        "num_inference_steps": int(state['image_steps']),
+        "num_images_per_prompt": int(state['image_batch_size']),
+        "generator": generator,
+    }
+
+    # Add pipeline-specific parameters
+    if pipeline_type == 'qwenimage':
+        # Qwen-Image uses true_cfg_scale instead of guidance_scale
+        gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
+    else:
+        # Z-Image and others use guidance_scale
+        gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
+
     t0 = time.time()
     for i in range(int(state['image_batch_count'])):
         generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(
-            prompt=state['image_prompt'],
-            negative_prompt=state['image_neg_prompt'],
-            height=int(state['image_height']),
-            width=int(state['image_width']),
-            num_inference_steps=int(state['image_steps']),
-            guidance_scale=0.0,
-            num_images_per_prompt=int(state['image_batch_size']),
-            generator=generator,
-        ).images
+        batch_results = shared.image_model(**gen_kwargs).images
         all_images.extend(batch_results)
 
     t1 = time.time()
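The practical effect of the branch above is that the same CFG slider value reaches the pipeline under a different keyword. A hedged sketch of the two resulting calls (prompts, step counts, and scale values are illustrative; `generator` is assumed to exist):

    # Qwen-Image pipelines take true_cfg_scale
    images = shared.image_model(prompt="a cat", negative_prompt="", true_cfg_scale=4.0,
                                num_inference_steps=30, generator=generator).images

    # Z-Image and other pipelines keep the standard guidance_scale keyword
    images = shared.image_model(prompt="a cat", guidance_scale=0.0,
                                num_inference_steps=30, generator=generator).images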