Compare commits

...

16 commits

Author SHA1 Message Date
oobabooga f45412676d Minor label changes 2025-12-01 19:16:53 -08:00
oobabooga 5c61fcf479 Autosave on prompt change 2025-12-01 18:59:11 -08:00
oobabooga d75d7a3a63 Add a CFG scale slider, add qwen3 magic 2025-12-01 18:41:15 -08:00
oobabooga 151b552bc3 Decrease the resolution step to allow for 1368 2025-12-01 18:24:02 -08:00
oobabooga 322aab3410 Increase the image_steps maximum 2025-12-01 18:20:47 -08:00
oobabooga f46f49e26c Initial Qwen-Image support 2025-12-01 18:18:15 -08:00
oobabooga 225b8c326b Try to not break portable builds 2025-12-01 17:13:16 -08:00
oobabooga 5fb1380ac1 Handle URLs like https://huggingface.co/Qwen/Qwen-Image 2025-12-01 17:09:32 -08:00
oobabooga 7dfb6e9c57 Add quantization options (bnb and quanto) 2025-12-01 17:05:42 -08:00
oobabooga a7808f7f42 Make filenames always have the same size 2025-12-01 16:06:02 -08:00
oobabooga 748e2e55fd Add steps/second info to log message 2025-12-01 15:44:31 -08:00
oobabooga 6a7209a842 Add PNG metadata, add pagination to Gallery tab 2025-12-01 15:41:58 -08:00
oobabooga c8e9d7fc37 Fix the gallery height after the previous commit 2025-12-01 14:00:41 -08:00
oobabooga b4738beaf8 Remove the seed UI element 2025-12-01 13:59:10 -08:00
oobabooga 75796f5a58 Set gallery heights 2025-12-01 13:44:18 -08:00
oobabooga 990f0e2468 Revert "Revert CSS changes" (reverts commit 366fe353f0) 2025-12-01 13:29:22 -08:00
6 changed files with 601 additions and 100 deletions

View file

@@ -1681,3 +1681,59 @@ button#swap-height-width {
right: 0;
border: 0;
}
#image-output-gallery, #image-output-gallery > :nth-child(2) {
height: calc(100vh - 83px);
max-height: calc(100vh - 83px);
}
#image-history-gallery, #image-history-gallery > :nth-child(2) {
height: calc(100vh - 174px);
max-height: calc(100vh - 174px);
}
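/* Editorial note: the fixed offsets subtracted from 100vh above (83px and
   174px) presumably correspond to the surrounding header and pagination
   chrome, so each gallery fills the remaining viewport height. */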
/* Additional CSS for the paginated image gallery */
/* Page info styling */
#image-page-info {
display: flex;
align-items: center;
justify-content: center;
min-width: 200px;
font-size: 0.9em;
color: var(--body-text-color-subdued);
}
/* Settings display panel */
#image-ai-tab .settings-display-panel {
background: var(--background-fill-secondary);
padding: 12px;
border-radius: 8px;
font-size: 0.9em;
max-height: 300px;
overflow-y: auto;
margin-top: 8px;
}
/* Gallery status message */
#image-ai-tab .gallery-status {
color: var(--color-accent);
font-size: 0.85em;
margin-top: 4px;
}
/* Pagination button row alignment */
#image-ai-tab .pagination-controls {
display: flex;
align-items: center;
gap: 8px;
flex-wrap: wrap;
}
/* Selected image preview container */
#image-ai-tab .selected-preview-container {
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 8px;
background: var(--background-fill-secondary);
}

View file

@@ -1,14 +1,97 @@
import time
import torch
import modules.shared as shared
from modules.logging_colors import logger
from modules.torch_utils import get_device
from modules.utils import resolve_model_path
def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
def get_quantization_config(quant_method):
"""
Get the appropriate quantization config based on the selected method.
Args:
quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
Returns:
PipelineQuantizationConfig or None
"""
import torch
from diffusers import BitsAndBytesConfig, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
if quant_method == 'none' or not quant_method:
return None
# Bitsandbytes 8-bit quantization
elif quant_method == 'bnb-8bit':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": BitsAndBytesConfig(
load_in_8bit=True
)
}
)
# Bitsandbytes 4-bit quantization
elif quant_method == 'bnb-4bit':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
}
)
# Quanto 8-bit quantization
elif quant_method == 'quanto-8bit':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8")
}
)
# Quanto 4-bit quantization
elif quant_method == 'quanto-4bit':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int4")
}
)
# Quanto 2-bit quantization
elif quant_method == 'quanto-2bit':
return PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int2")
}
)
else:
logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
return None
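# Illustrative usage (editorial sketch, not part of this diff): diffusers'
# pipeline-level quantization takes a per-component mapping, so only the
# "transformer" is quantized here while the text encoder and VAE keep the
# pipeline dtype. The model id below is an example.
#
#   from diffusers import DiffusionPipeline
#   config = get_quantization_config('bnb-4bit')
#   pipe = DiffusionPipeline.from_pretrained(
#       "Qwen/Qwen-Image",
#       quantization_config=config,
#       torch_dtype=torch.bfloat16,
#   )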
def get_pipeline_type(pipe):
"""
Detect the pipeline type based on the loaded pipeline class.
Returns:
str: 'zimage', 'qwenimage', or 'unknown'
"""
class_name = pipe.__class__.__name__
if 'ZImage' in class_name:
return 'zimage'
elif 'QwenImage' in class_name:
return 'qwenimage'
else:
return 'unknown'
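# Example: 'ZImagePipeline' -> 'zimage', 'QwenImagePipeline' -> 'qwenimage';
# any other pipeline class falls back to 'unknown'.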
def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
"""
Load a diffusers image generation model.
@@ -18,10 +101,12 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
"""
from diffusers import PipelineQuantizationConfig, ZImagePipeline
import torch
from diffusers import DiffusionPipeline
logger.info(f"Loading image model \"{model_name}\"")
logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
t0 = time.time()
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
@@ -30,49 +115,49 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
model_path = resolve_model_path(model_name, image_model=True)
try:
# Define quantization config for 8-bit
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
)
# Get quantization config based on selected method
pipeline_quant_config = get_quantization_config(quant_method)
# Define quantization config for 4-bit
# pipeline_quant_config = PipelineQuantizationConfig(
# quant_backend="bitsandbytes_4bit",
# quant_kwargs={
# "load_in_4bit": True,
# "bnb_4bit_quant_type": "nf4", # Or "fp4" for floating point
# "bnb_4bit_compute_dtype": torch.bfloat16, # For faster computation
# "bnb_4bit_use_double_quant": True, # Nested quantization for extra savings
# },
# )
# Load the pipeline
load_kwargs = {
"torch_dtype": target_dtype,
"low_cpu_mem_usage": True,
}
pipe = ZImagePipeline.from_pretrained(
if pipeline_quant_config is not None:
load_kwargs["quantization_config"] = pipeline_quant_config
# Use DiffusionPipeline for automatic pipeline detection
# This handles both ZImagePipeline and QwenImagePipeline
pipe = DiffusionPipeline.from_pretrained(
str(model_path),
quantization_config=pipeline_quant_config,
torch_dtype=target_dtype,
low_cpu_mem_usage=True,
**load_kwargs
)
pipeline_type = get_pipeline_type(pipe)
if not cpu_offload:
pipe.to(get_device())
# Set attention backend
if attn_backend == 'flash_attention_2':
pipe.transformer.set_attention_backend("flash")
elif attn_backend == 'flash_attention_3':
pipe.transformer.set_attention_backend("_flash_3")
# sdpa is the default, no action needed
# Set attention backend (if supported by the pipeline)
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
if attn_backend == 'flash_attention_2':
pipe.transformer.set_attention_backend("flash")
elif attn_backend == 'flash_attention_3':
pipe.transformer.set_attention_backend("_flash_3")
# sdpa is the default, no action needed
if compile_model:
logger.info("Compiling model (first run will be slow)...")
pipe.transformer.compile()
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
logger.info("Compiling model (first run will be slow)...")
pipe.transformer.compile()
if cpu_offload:
pipe.enable_model_cpu_offload()
shared.image_model = pipe
shared.image_model_name = model_name
shared.image_pipeline_type = pipeline_type
logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
return pipe
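# Illustrative call (editorial sketch; the model name is an example and is
# resolved through resolve_model_path, so it must exist locally):
#
#   pipe = load_image_model('Qwen/Qwen-Image', dtype='bfloat16',
#                           quant_method='quanto-4bit', cpu_offload=True)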
@@ -90,6 +175,7 @@ def unload_image_model():
del shared.image_model
shared.image_model = None
shared.image_model_name = 'None'
shared.image_pipeline_type = None
from modules.torch_utils import clear_torch_cache
clear_torch_cache()

View file

@@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16',
group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
group.add_argument('--image-quant', type=str, default=None,
choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
help='Quantization method for image model.')
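# Example invocation combining the new flag with the existing image options
# (editorial; the entry point and model id are illustrative):
#   python server.py --image-model Qwen/Qwen-Image --image-quant bnb-4bit --image-cpu-offload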
# Model loader
group = parser.add_argument_group('Model loader')
@@ -311,14 +314,16 @@ settings = {
'image_height': 1024,
'image_aspect_ratio': '1:1 Square',
'image_steps': 9,
'image_cfg_scale': 0.0,
'image_seed': -1,
'image_batch_size': 1,
'image_batch_count': 1,
'image_model_menu': 'None',
'image_dtype': 'bfloat16',
'image_attn_backend': 'sdpa',
'image_compile': False,
'image_cpu_offload': False,
'image_compile': False,
'image_quant': 'none',
}
default_settings = copy.deepcopy(settings)
@@ -344,8 +349,8 @@ def do_cmd_flags_warnings():
def apply_image_model_cli_overrides():
"""Apply CLI flags for image model settings, overriding saved settings."""
if args.image_model:
"""Apply command-line overrides for image model settings."""
if args.image_model is not None:
settings['image_model_menu'] = args.image_model
if args.image_dtype is not None:
settings['image_dtype'] = args.image_dtype
@@ -355,6 +360,9 @@ def apply_image_model_cli_overrides():
settings['image_cpu_offload'] = True
if args.image_compile:
settings['image_compile'] = True
if args.image_quant is not None:
settings['image_quant'] = args.image_quant
def fix_loader_name(name):

View file

@@ -288,6 +288,7 @@ def list_interface_input_elements():
'image_height',
'image_aspect_ratio',
'image_steps',
'image_cfg_scale',
'image_seed',
'image_batch_size',
'image_batch_count',
@@ -296,6 +297,7 @@ def list_interface_input_elements():
'image_attn_backend',
'image_compile',
'image_cpu_offload',
'image_quant',
]
return elements
@@ -530,10 +532,13 @@ def setup_auto_save():
'include_past_attachments',
# Image generation tab (ui_image_generation.py)
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_cfg_scale',
'image_seed',
'image_batch_size',
'image_batch_count',
@@ -542,6 +547,7 @@ def setup_auto_save():
'image_attn_backend',
'image_compile',
'image_cpu_offload',
'image_quant',
]
for element_name in change_elements:

View file

@@ -1,14 +1,18 @@
import json
import os
import time
import traceback
from datetime import datetime
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
from modules.image_models import get_pipeline_type, load_image_model, unload_image_model
from modules.logging_colors import logger
from modules.utils import gradio
ASPECT_RATIOS = {
@@ -19,7 +23,26 @@ ASPECT_RATIOS = {
"Custom": None,
}
STEP = 32
STEP = 16
IMAGES_PER_PAGE = 64
# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_cfg_scale',
]
# Cache for all image paths
_image_cache = []
_cache_timestamp = 0
def round_to_step(value, step=STEP):
@@ -91,6 +114,216 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
return new_width, new_height, new_ratio
def build_generation_metadata(state, actual_seed):
"""Build metadata dict from generation settings."""
metadata = {}
for key in METADATA_SETTINGS_KEYS:
if key in state:
metadata[key] = state[key]
# Store the actual seed used (not -1)
metadata['image_seed'] = actual_seed
metadata['generated_at'] = datetime.now().isoformat()
metadata['model'] = shared.image_model_name
return metadata
def save_generated_images(images, state, actual_seed):
"""Save images with generation metadata embedded in PNG."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
metadata = build_generation_metadata(state, actual_seed)
metadata_json = json.dumps(metadata, ensure_ascii=False)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
png_info = PngInfo()
png_info.add_text("image_gen_settings", metadata_json)
# Save with metadata
img.save(filepath, pnginfo=png_info)
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with Image.open(image_path) as img:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
except Exception as e:
logger.debug(f"Could not read metadata from {image_path}: {e}")
return None
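# Round-trip sketch (editorial, not part of this diff): PngInfo.add_text()
# writes a PNG text chunk that PIL exposes again as Image.text on reload,
# which is what read_image_metadata() relies on.
#
#   from PIL import Image
#   from PIL.PngImagePlugin import PngInfo
#   info = PngInfo()
#   info.add_text("image_gen_settings", json.dumps({"image_steps": 9}))
#   Image.new("RGB", (8, 8)).save("/tmp/meta_test.png", pnginfo=info)
#   with Image.open("/tmp/meta_test.png") as im:
#       assert json.loads(im.text["image_gen_settings"])["image_steps"] == 9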
def format_metadata_for_display(metadata):
"""Format metadata as readable text."""
if not metadata:
return "No generation settings found in this image."
lines = ["**Generation Settings**", ""]
# Display in a nice order
display_order = [
('image_prompt', 'Prompt'),
('image_neg_prompt', 'Negative Prompt'),
('image_width', 'Width'),
('image_height', 'Height'),
('image_aspect_ratio', 'Aspect Ratio'),
('image_steps', 'Steps'),
('image_cfg_scale', 'CFG Scale'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
('model', 'Model'),
('generated_at', 'Generated At'),
]
for key, label in display_order:
if key in metadata:
value = metadata[key]
if key in ['image_prompt', 'image_neg_prompt'] and value:
# Truncate long prompts for display
if len(str(value)) > 200:
value = str(value)[:200] + "..."
lines.append(f"**{label}:** {value}")
return "\n\n".join(lines)
def get_all_history_images(force_refresh=False):
"""Get all history images sorted by modification time (newest first). Uses caching."""
global _image_cache, _cache_timestamp
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
# Check if we need to refresh cache
current_time = time.time()
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
return _image_cache
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
_image_cache = [x[0] for x in image_files]
_cache_timestamp = current_time
return _image_cache
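# Note: repeated gallery renders within 2 seconds reuse the cached list; the
# Refresh button goes through refresh_gallery(), which passes
# force_refresh=True to bypass the TTL.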
def get_paginated_images(page=0, force_refresh=False):
"""Get images for a specific page."""
all_images = get_all_history_images(force_refresh)
total_images = len(all_images)
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
# Clamp page to valid range
page = max(0, min(page, total_pages - 1))
start_idx = page * IMAGES_PER_PAGE
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
page_images = all_images[start_idx:end_idx]
return page_images, page, total_pages, total_images
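# Worked example: with 130 images and IMAGES_PER_PAGE = 64, total_pages =
# ceil(130 / 64) = 3; requesting page 5 clamps to page 2 (0-indexed), which
# yields indices 128-129.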
def refresh_gallery(current_page=0):
"""Refresh gallery with current page."""
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def go_to_page(page_num, current_page):
"""Go to a specific page (1-indexed input)."""
try:
page = int(page_num) - 1 # Convert to 0-indexed
except (ValueError, TypeError):
page = current_page
images, page, total_pages, total_images = get_paginated_images(page)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def next_page(current_page):
"""Go to next page."""
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def prev_page(current_page):
"""Go to previous page."""
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def on_gallery_select(evt: gr.SelectData, current_page):
"""Handle image selection from gallery."""
if evt.index is None:
return "", "Select an image to view its settings"
# Get the current page's images to find the actual file path
all_images = get_all_history_images()
total_images = len(all_images)
# Calculate the actual index in the full list
start_idx = current_page * IMAGES_PER_PAGE
actual_idx = start_idx + evt.index
if actual_idx >= total_images:
return "", "Image not found"
image_path = all_images[actual_idx]
metadata = read_image_metadata(image_path)
metadata_display = format_metadata_for_display(metadata)
return image_path, metadata_display
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
return [gr.update()] * 10 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
return [gr.update()] * 10 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
gr.update(value=metadata.get('image_prompt', '')),
gr.update(value=metadata.get('image_neg_prompt', '')),
gr.update(value=metadata.get('image_width', 1024)),
gr.update(value=metadata.get('image_height', 1024)),
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
gr.update(value=metadata.get('image_steps', 9)),
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
gr.update(value=metadata.get('image_cfg_scale', 0.0)),
]
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
return updates + [status]
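# Note: the order of these updates must match the outputs tuple wired to
# image_send_to_generate in create_event_handlers() below, with the trailing
# status string mapping to image_gallery_status.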
def create_ui():
if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu']
@@ -115,15 +348,15 @@ def create_ui():
value=shared.settings['image_neg_prompt']
)
shared.gradio['image_generate_btn'] = gr.Button("GENERATE", variant="primary", size="lg", elem_id="gen-btn")
shared.gradio['image_generate_btn'] = gr.Button("GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
gr.Markdown("### Dimensions")
with gr.Row():
with gr.Column():
shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")
with gr.Column():
shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")
shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
with gr.Row():
@@ -137,7 +370,14 @@ def create_ui():
gr.Markdown("### Config")
with gr.Row():
with gr.Column():
shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
shared.gradio['image_cfg_scale'] = gr.Slider(
0.0, 10.0,
value=0.0,
step=0.1,
label="CFG Scale",
info="Z-Image Turbo: 0.0 | Qwen: 4.0"
)
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
@@ -146,14 +386,41 @@ def create_ui():
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
with gr.Row():
shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)
# TAB 2: GALLERY
# TAB 2: GALLERY (with pagination)
with gr.TabItem("Gallery"):
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
with gr.Column(scale=3):
# Pagination controls
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
# State for current page and selected image path
shared.gradio['image_current_page'] = gr.State(value=0)
shared.gradio['image_selected_path'] = gr.State(value="")
# Paginated gallery using gr.Gallery
shared.gradio['image_history_gallery'] = gr.Gallery(
value=lambda: get_paginated_images(0)[0],
label="Image History",
show_label=False,
columns=6,
object_fit="cover",
height="auto",
allow_preview=True,
elem_id="image-history-gallery"
)
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL
with gr.TabItem("Model"):
@@ -173,6 +440,13 @@ def create_ui():
gr.Markdown("## Settings")
with gr.Row():
with gr.Column():
shared.gradio['image_quant'] = gr.Dropdown(
label='Quantization',
choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
value=shared.settings['image_quant'],
info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
)
shared.gradio['image_dtype'] = gr.Dropdown(
choices=['bfloat16', 'float16'],
value=shared.settings['image_dtype'],
@@ -242,15 +516,15 @@ def create_event_handlers():
# Generation
shared.gradio['image_generate_btn'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_neg_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
generate, gradio('interface_state'), gradio('image_output_gallery'))
# Model management
shared.gradio['image_refresh_models'].click(
@@ -262,7 +536,7 @@ def create_event_handlers():
shared.gradio['image_load_model'].click(
load_image_model_wrapper,
gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
gradio('image_model_status'),
show_progress=True
)
@@ -281,20 +555,77 @@ def create_event_handlers():
show_progress=True
)
# History
# Gallery pagination handlers
shared.gradio['image_refresh_history'].click(
get_history_images,
None,
gradio('image_history_gallery'),
refresh_gallery,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_next_page'].click(
next_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_prev_page'].click(
prev_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_go_to_page'].click(
go_to_page,
gradio('image_page_input', 'image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
# Image selection from gallery
shared.gradio['image_history_gallery'].select(
on_gallery_select,
gradio('image_current_page'),
gradio('image_selected_path', 'image_settings_display'),
show_progress=False
)
# Send to Generate
shared.gradio['image_send_to_generate'].click(
send_to_generate,
gradio('image_selected_path'),
gradio(
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_cfg_scale',
'image_gallery_status'
),
show_progress=False
)
def generate(state):
"""
Generate images using the loaded model.
Automatically adjusts parameters based on pipeline type.
"""
import torch
import numpy as np
model_name = state['image_model_menu']
if not model_name or model_name == 'None':
return [], "No image model selected. Go to the Model tab and select a model."
logger.error("No image model selected. Go to the Model tab and select a model.")
return []
if shared.image_model is None:
result = load_image_model(
@@ -302,10 +633,12 @@ def generate(state):
dtype=state['image_dtype'],
attn_backend=state['image_attn_backend'],
cpu_offload=state['image_cpu_offload'],
compile_model=state['image_compile']
compile_model=state['image_compile'],
quant_method=state['image_quant']
)
if result is None:
return [], f"Failed to load model `{model_name}`."
logger.error(f"Failed to load model `{model_name}`.")
return []
shared.image_model_name = model_name
@@ -316,25 +649,56 @@ def generate(state):
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
# Get pipeline type for parameter adjustment
pipeline_type = getattr(shared, 'image_pipeline_type', None)
if pipeline_type is None:
pipeline_type = get_pipeline_type(shared.image_model)
# Process Prompt
prompt = state['image_prompt']
# Apply "Positive Magic" for Qwen models only
if pipeline_type == 'qwenimage':
magic_suffix = ", Ultra HD, 4K, cinematic composition"
# Avoid duplication if user already added it
if magic_suffix.strip(", ") not in prompt:
prompt += magic_suffix
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
"negative_prompt": state['image_neg_prompt'],
"height": int(state['image_height']),
"width": int(state['image_width']),
"num_inference_steps": int(state['image_steps']),
"num_images_per_prompt": int(state['image_batch_size']),
"generator": generator,
}
# Add pipeline-specific parameters for CFG
cfg_val = state.get('image_cfg_scale', 0.0)
if pipeline_type == 'qwenimage':
# Qwen-Image uses true_cfg_scale (typically 4.0)
gen_kwargs["true_cfg_scale"] = cfg_val
else:
# Z-Image and others use guidance_scale (typically 0.0 for Turbo)
gen_kwargs["guidance_scale"] = cfg_val
t0 = time.time()
for i in range(int(state['image_batch_count'])):
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(
prompt=state['image_prompt'],
negative_prompt=state['image_neg_prompt'],
height=int(state['image_height']),
width=int(state['image_width']),
num_inference_steps=int(state['image_steps']),
guidance_scale=0.0,
num_images_per_prompt=int(state['image_batch_size']),
generator=generator,
).images
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
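# Example: batch_size=2 with batch_count=3 issues three pipeline calls seeded
# seed, seed+1 and seed+2, each returning 2 images (6 images total).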
save_generated_images(all_images, state['image_prompt'], seed)
return all_images, f"Seed: {seed}"
t1 = time.time()
save_generated_images(all_images, state, seed)
logger.info(f'Images generated in {(t1-t0):.2f} seconds ({state["image_steps"]/(t1-t0):.2f} steps/s, seed {seed})')
return all_images
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
if not model_name or model_name == 'None':
yield "No model selected"
return
@@ -348,12 +712,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
dtype=dtype,
attn_backend=attn_backend,
cpu_offload=cpu_offload,
compile_model=compile_model
compile_model=compile_model,
quant_method=quant_method
)
if result is not None:
shared.image_model_name = model_name
yield f"✓ Loaded **{model_name}**"
yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
else:
yield f"✗ Failed to load `{model_name}`"
except Exception:
@@ -375,6 +740,12 @@ def download_image_model_wrapper(model_path):
return
try:
model_path = model_path.strip()
if model_path.startswith('https://huggingface.co/'):
model_path = model_path[len('https://huggingface.co/'):]
elif model_path.startswith('huggingface.co/'):
model_path = model_path[len('huggingface.co/'):]
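# e.g. 'https://huggingface.co/Qwen/Qwen-Image' -> 'Qwen/Qwen-Image'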
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
@@ -396,30 +767,3 @@ def download_image_model_wrapper(model_path):
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{seed}_{idx}.png"
img.save(os.path.join(folder_path, filename))
def get_history_images():
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]

View file

@@ -172,7 +172,8 @@ def create_interface():
ui_chat.create_event_handlers()
ui_default.create_event_handlers()
ui_notebook.create_event_handlers()
ui_image_generation.create_event_handlers()
if not shared.args.portable:
ui_image_generation.create_event_handlers()
# Other events
ui_file_saving.create_event_handlers()