Initial Qwen-Image support

oobabooga 2025-12-01 18:18:15 -08:00
parent 225b8c326b
commit f46f49e26c
2 changed files with 64 additions and 20 deletions

View file

@@ -75,6 +75,22 @@ def get_quantization_config(quant_method):
     return None
 
 
+def get_pipeline_type(pipe):
+    """
+    Detect the pipeline type based on the loaded pipeline class.
+
+    Returns:
+        str: 'zimage', 'qwenimage', or 'unknown'
+    """
+    class_name = pipe.__class__.__name__
+    if 'ZImage' in class_name:
+        return 'zimage'
+    elif 'QwenImage' in class_name:
+        return 'qwenimage'
+    else:
+        return 'unknown'
+
+
 def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
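
The detection above keys on nothing but the class name, so it can be exercised without loading any weights. A minimal sketch, using hypothetical stand-in classes in place of the real diffusers pipelines:

    class ZImagePipeline:  # stand-in; the real class comes from diffusers
        pass

    class QwenImagePipeline:  # stand-in
        pass

    assert get_pipeline_type(ZImagePipeline()) == 'zimage'
    assert get_pipeline_type(QwenImagePipeline()) == 'qwenimage'
    assert get_pipeline_type(object()) == 'unknown'  # anything else falls through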
@@ -88,7 +104,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
     import torch
-    from diffusers import ZImagePipeline
+    from diffusers import DiffusionPipeline
 
     logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()
@@ -111,30 +127,37 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     if pipeline_quant_config is not None:
         load_kwargs["quantization_config"] = pipeline_quant_config
 
-    pipe = ZImagePipeline.from_pretrained(
+    # Use DiffusionPipeline for automatic pipeline detection
+    # This handles both ZImagePipeline and QwenImagePipeline
+    pipe = DiffusionPipeline.from_pretrained(
         str(model_path),
         **load_kwargs
     )
 
+    pipeline_type = get_pipeline_type(pipe)
+
     if not cpu_offload:
         pipe.to(get_device())
 
-    # Set attention backend
-    if attn_backend == 'flash_attention_2':
-        pipe.transformer.set_attention_backend("flash")
-    elif attn_backend == 'flash_attention_3':
-        pipe.transformer.set_attention_backend("_flash_3")
-    # sdpa is the default, no action needed
+    # Set attention backend (if supported by the pipeline)
+    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+        if attn_backend == 'flash_attention_2':
+            pipe.transformer.set_attention_backend("flash")
+        elif attn_backend == 'flash_attention_3':
+            pipe.transformer.set_attention_backend("_flash_3")
+        # sdpa is the default, no action needed
 
     if compile_model:
-        logger.info("Compiling model (first run will be slow)...")
-        pipe.transformer.compile()
+        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+            logger.info("Compiling model (first run will be slow)...")
+            pipe.transformer.compile()
 
     if cpu_offload:
         pipe.enable_model_cpu_offload()
 
     shared.image_model = pipe
     shared.image_model_name = model_name
+    shared.image_pipeline_type = pipeline_type
 
     logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
     return pipe
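
Both Z-Image and Qwen-Image ship transformer backbones, but the `hasattr` guards above also make the loader a safe no-op for pipelines without a `.transformer` attribute, or for diffusers versions lacking `set_attention_backend`. The same pattern as a standalone helper, sketched here with a hypothetical function name:

    def set_attention_backend_if_supported(pipe, attn_backend):
        # Skip silently when the pipeline has no transformer, or the
        # installed diffusers has no set_attention_backend() method.
        if not (hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend')):
            return
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == 'flash_attention_3':
            pipe.transformer.set_attention_backend("_flash_3")
        # 'sdpa' is the diffusers default, so nothing to do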
@@ -152,6 +175,7 @@ def unload_image_model():
     del shared.image_model
     shared.image_model = None
     shared.image_model_name = 'None'
+    shared.image_pipeline_type = None
 
     from modules.torch_utils import clear_torch_cache
     clear_torch_cache()
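
The core change in this file is loading through `DiffusionPipeline` instead of a hard-coded `ZImagePipeline`. `DiffusionPipeline.from_pretrained` resolves the concrete pipeline class from the checkpoint's `model_index.json`, so one load path serves both models. A minimal sketch of that load path, assuming a local checkpoint directory (the path is hypothetical):

    import torch
    from diffusers import DiffusionPipeline

    # from_pretrained reads model_index.json and instantiates the class named
    # there (ZImagePipeline, QwenImagePipeline, ...), no branching required.
    pipe = DiffusionPipeline.from_pretrained(
        "user_data/image_models/Qwen-Image",  # hypothetical local path
        torch_dtype=torch.bfloat16,
    )
    print(pipe.__class__.__name__)  # e.g. 'QwenImagePipeline'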

View file

@@ -604,7 +604,12 @@ def create_event_handlers():
 def generate(state):
     """
     Generate images using the loaded model.
+
+    Automatically adjusts parameters based on pipeline type.
     """
     import torch
     import numpy as np
 
     model_name = state['image_model_menu']
@@ -634,19 +639,34 @@ def generate(state):
     generator = torch.Generator("cuda").manual_seed(int(seed))
     all_images = []
 
+    # Get pipeline type for parameter adjustment
+    pipeline_type = getattr(shared, 'image_pipeline_type', None)
+    if pipeline_type is None:
+        pipeline_type = get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs based on pipeline type
+    gen_kwargs = {
+        "prompt": state['image_prompt'],
+        "negative_prompt": state['image_neg_prompt'],
+        "height": int(state['image_height']),
+        "width": int(state['image_width']),
+        "num_inference_steps": int(state['image_steps']),
+        "num_images_per_prompt": int(state['image_batch_size']),
+        "generator": generator,
+    }
+
+    # Add pipeline-specific parameters
+    if pipeline_type == 'qwenimage':
+        # Qwen-Image uses true_cfg_scale instead of guidance_scale
+        gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
+    else:
+        # Z-Image and others use guidance_scale
+        gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
+
     t0 = time.time()
     for i in range(int(state['image_batch_count'])):
         generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(
-            prompt=state['image_prompt'],
-            negative_prompt=state['image_neg_prompt'],
-            height=int(state['image_height']),
-            width=int(state['image_width']),
-            num_inference_steps=int(state['image_steps']),
-            guidance_scale=0.0,
-            num_images_per_prompt=int(state['image_batch_size']),
-            generator=generator,
-        ).images
+        batch_results = shared.image_model(**gen_kwargs).images
         all_images.extend(batch_results)
 
     t1 = time.time()
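
The reason for the split: in diffusers, `QwenImagePipeline` performs true classifier-free guidance controlled by `true_cfg_scale` (values above 1.0 activate the negative prompt), while Z-Image's pipeline takes the usual `guidance_scale`. A condensed sketch of the dispatch (the helper name is hypothetical; keys and defaults mirror the code above):

    def build_gen_kwargs(state, pipeline_type, generator):
        kwargs = {
            "prompt": state['image_prompt'],
            "negative_prompt": state['image_neg_prompt'],
            "height": int(state['image_height']),
            "width": int(state['image_width']),
            "num_inference_steps": int(state['image_steps']),
            "num_images_per_prompt": int(state['image_batch_size']),
            "generator": generator,
        }
        if pipeline_type == 'qwenimage':
            # true_cfg_scale > 1.0 enables CFG with the negative prompt
            kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
        else:
            # 0.0 was the previously hard-coded value for Z-Image
            kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
        return kwargs

With the kwargs built this way, the loop body reduces to `shared.image_model(**gen_kwargs).images`, which is exactly what the rewritten call does.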