From f46f49e26c9bd945f73c2c9b1346803d67bc96b2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:18:15 -0800
Subject: [PATCH] Initial Qwen-Image support

---
 modules/image_models.py        | 44 ++++++++++++++++++++++++++--------
 modules/ui_image_generation.py | 40 +++++++++++++++++++++++--------
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/modules/image_models.py b/modules/image_models.py
index fe149253..e4831758 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -75,6 +75,22 @@ def get_quantization_config(quant_method):
     return None
 
 
+def get_pipeline_type(pipe):
+    """
+    Detect the pipeline type based on the loaded pipeline class.
+
+    Returns:
+        str: 'zimage', 'qwenimage', or 'unknown'
+    """
+    class_name = pipe.__class__.__name__
+    if 'ZImage' in class_name:
+        return 'zimage'
+    elif 'QwenImage' in class_name:
+        return 'qwenimage'
+    else:
+        return 'unknown'
+
+
 def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
@@ -88,7 +104,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
     import torch
-    from diffusers import ZImagePipeline
+    from diffusers import DiffusionPipeline
 
     logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()
@@ -111,30 +127,37 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     if pipeline_quant_config is not None:
         load_kwargs["quantization_config"] = pipeline_quant_config
 
-    pipe = ZImagePipeline.from_pretrained(
+    # Use DiffusionPipeline for automatic pipeline detection
+    # This handles both ZImagePipeline and QwenImagePipeline
+    pipe = DiffusionPipeline.from_pretrained(
         str(model_path),
         **load_kwargs
     )
 
+    pipeline_type = get_pipeline_type(pipe)
+
     if not cpu_offload:
         pipe.to(get_device())
 
-    # Set attention backend
-    if attn_backend == 'flash_attention_2':
-        pipe.transformer.set_attention_backend("flash")
-    elif attn_backend == 'flash_attention_3':
-        pipe.transformer.set_attention_backend("_flash_3")
-    # sdpa is the default, no action needed
+    # Set attention backend (if supported by the pipeline)
+    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+        if attn_backend == 'flash_attention_2':
+            pipe.transformer.set_attention_backend("flash")
+        elif attn_backend == 'flash_attention_3':
+            pipe.transformer.set_attention_backend("_flash_3")
+        # sdpa is the default, no action needed
 
     if compile_model:
-        logger.info("Compiling model (first run will be slow)...")
-        pipe.transformer.compile()
+        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+            logger.info("Compiling model (first run will be slow)...")
+            pipe.transformer.compile()
 
     if cpu_offload:
         pipe.enable_model_cpu_offload()
 
     shared.image_model = pipe
     shared.image_model_name = model_name
+    shared.image_pipeline_type = pipeline_type
     logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
 
     return pipe
@@ -152,6 +175,7 @@ def unload_image_model():
     del shared.image_model
     shared.image_model = None
     shared.image_model_name = 'None'
+    shared.image_pipeline_type = None
 
     from modules.torch_utils import clear_torch_cache
     clear_torch_cache()
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 09faf423..42c8c21f 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -604,7 +604,12 @@ def create_event_handlers():
 
 
 def generate(state):
+    """
+    Generate images using the loaded model.
+    Automatically adjusts parameters based on pipeline type.
+    """
     import torch
+    import numpy as np
 
     model_name = state['image_model_menu']
 
@@ -634,19 +639,34 @@ def generate(state):
     generator = torch.Generator("cuda").manual_seed(int(seed))
     all_images = []
 
+    # Get pipeline type for parameter adjustment
+    pipeline_type = getattr(shared, 'image_pipeline_type', None)
+    if pipeline_type is None:
+        pipeline_type = get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs based on pipeline type
+    gen_kwargs = {
+        "prompt": state['image_prompt'],
+        "negative_prompt": state['image_neg_prompt'],
+        "height": int(state['image_height']),
+        "width": int(state['image_width']),
+        "num_inference_steps": int(state['image_steps']),
+        "num_images_per_prompt": int(state['image_batch_size']),
+        "generator": generator,
+    }
+
+    # Add pipeline-specific parameters
+    if pipeline_type == 'qwenimage':
+        # Qwen-Image uses true_cfg_scale instead of guidance_scale
+        gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
+    else:
+        # Z-Image and others use guidance_scale
+        gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
+
     t0 = time.time()
     for i in range(int(state['image_batch_count'])):
         generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(
-            prompt=state['image_prompt'],
-            negative_prompt=state['image_neg_prompt'],
-            height=int(state['image_height']),
-            width=int(state['image_width']),
-            num_inference_steps=int(state['image_steps']),
-            guidance_scale=0.0,
-            num_images_per_prompt=int(state['image_batch_size']),
-            generator=generator,
-        ).images
+        batch_results = shared.image_model(**gen_kwargs).images
         all_images.extend(batch_results)
 
     t1 = time.time()
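A minimal sketch of how the new dispatch path is expected to behave once the patch is applied, assuming a Qwen-Image checkpoint is already present in the image models directory. The directory name "Qwen-Image" and the printed values below are illustrative rather than defined by the patch; the import paths simply follow the modules/ layout shown in the diff.

# Illustrative sketch only; "Qwen-Image" is a placeholder model directory name.
from modules import shared
from modules.image_models import get_pipeline_type, load_image_model

pipe = load_image_model("Qwen-Image", dtype='bfloat16', attn_backend='sdpa')

# DiffusionPipeline.from_pretrained() resolves the concrete pipeline class from the
# checkpoint's model_index.json, so the loaded class name identifies the model family.
print(pipe.__class__.__name__)     # expected: QwenImagePipeline for a Qwen-Image checkpoint
print(get_pipeline_type(pipe))     # 'qwenimage'
print(shared.image_pipeline_type)  # 'qwenimage', recorded by load_image_model()

# generate() in ui_image_generation.py will then pass true_cfg_scale to this pipeline,
# while Z-Image checkpoints keep the existing guidance_scale path.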