text-generation-webui/modules/ui_image_generation.py


import json
import os
import time
import traceback
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo

from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
from modules.logging_colors import logger
from modules.utils import gradio
ASPECT_RATIOS = {
"1:1 Square": (1, 1),
"16:9 Cinema": (16, 9),
"9:16 Mobile": (9, 16),
"4:3 Photo": (4, 3),
"Custom": None,
}
STEP = 16
IMAGES_PER_PAGE = 64
# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
]
# Module-level cache of saved image paths, refreshed at most every 2 seconds
_image_cache = []
_cache_timestamp = 0
def round_to_step(value, step=STEP):
return round(value / step) * step
def clamp(value, min_val, max_val):
return max(min_val, min(max_val, value))
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return current_width, current_height
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
if w_ratio == h_ratio:
base = min(current_width, current_height)
new_width = base
new_height = base
elif w_ratio < h_ratio:
new_width = current_width
new_height = round_to_step(current_width * h_ratio / w_ratio)
else:
new_height = current_height
new_width = round_to_step(current_height * w_ratio / h_ratio)
new_width = clamp(new_width, 256, 2048)
new_height = clamp(new_height, 256, 2048)
return int(new_width), int(new_height)
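# For example, snapping a 1024x1024 canvas to a preset keeps one side fixed and
# rounds the other to a multiple of STEP (16), clamped to the 256-2048 range:
#   apply_aspect_ratio("16:9 Cinema", 1024, 1024) -> (1824, 1024)
#   apply_aspect_ratio("9:16 Mobile", 1024, 1024) -> (1024, 1824)
#   apply_aspect_ratio("1:1 Square", 1280, 768)   -> (768, 768)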
def update_height_from_width(width, aspect_ratio):
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
new_height = round_to_step(width * h_ratio / w_ratio)
new_height = clamp(new_height, 256, 2048)
return int(new_height)
def update_width_from_height(height, aspect_ratio):
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
new_width = round_to_step(height * w_ratio / h_ratio)
new_width = clamp(new_width, 256, 2048)
return int(new_width)
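# Example of the linked sliders: with "16:9 Cinema" selected,
#   update_height_from_width(1024, "16:9 Cinema") -> 576
#   update_width_from_height(576, "16:9 Cinema")  -> 1024
# so dragging either slider keeps the pair consistent with the chosen ratio.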
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
new_width, new_height = height, width
new_ratio = "Custom"
for name, ratios in ASPECT_RATIOS.items():
if ratios is None:
continue
w_r, h_r = ratios
expected_height = new_width * h_r / w_r
if abs(expected_height - new_height) < STEP:
new_ratio = name
break
return new_width, new_height, new_ratio
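# Example: swap_dimensions_and_update_ratio(1024, 576, "16:9 Cinema") returns
# (576, 1024, "9:16 Mobile"): the dimensions are exchanged and the ratio radio
# is re-derived from whichever preset matches the new shape within one STEP.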
def build_generation_metadata(state, actual_seed):
"""Build metadata dict from generation settings."""
metadata = {}
for key in METADATA_SETTINGS_KEYS:
if key in state:
metadata[key] = state[key]
# Store the actual seed used (not -1)
metadata['image_seed'] = actual_seed
metadata['generated_at'] = datetime.now().isoformat()
metadata['model'] = shared.image_model_name
return metadata
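# Illustrative shape of the resulting dict (values below are placeholders):
# {
#   "image_prompt": "...", "image_neg_prompt": "...",
#   "image_width": 1024, "image_height": 1024, "image_aspect_ratio": "1:1 Square",
#   "image_steps": 9, "image_seed": 123456789, "image_batch_size": 1,
#   "image_batch_count": 1, "generated_at": "<ISO 8601 timestamp>", "model": "<model name>"
# }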
def save_generated_images(images, state, actual_seed):
"""Save images with generation metadata embedded in PNG."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
metadata = build_generation_metadata(state, actual_seed)
metadata_json = json.dumps(metadata, ensure_ascii=False)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
png_info = PngInfo()
png_info.add_text("image_gen_settings", metadata_json)
# Save with metadata
img.save(filepath, pnginfo=png_info)
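# Resulting layout on disk, following the format strings above:
#   user_data/image_outputs/<YYYY-MM-DD>/<HH-MM-SS>_<seed, 10 digits>_<index, 3 digits>.png
# with the settings JSON embedded in the PNG text chunk "image_gen_settings".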
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with Image.open(image_path) as img:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
except Exception as e:
logger.debug(f"Could not read metadata from {image_path}: {e}")
return None
def format_metadata_for_display(metadata):
"""Format metadata as readable text."""
if not metadata:
return "No generation settings found in this image."
lines = ["**Generation Settings**", ""]
# Display in a nice order
display_order = [
('image_prompt', 'Prompt'),
('image_neg_prompt', 'Negative Prompt'),
('image_width', 'Width'),
('image_height', 'Height'),
('image_aspect_ratio', 'Aspect Ratio'),
('image_steps', 'Steps'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
('model', 'Model'),
('generated_at', 'Generated At'),
]
for key, label in display_order:
if key in metadata:
value = metadata[key]
if key in ['image_prompt', 'image_neg_prompt'] and value:
# Truncate long prompts for display
if len(str(value)) > 200:
value = str(value)[:200] + "..."
lines.append(f"**{label}:** {value}")
return "\n\n".join(lines)
def get_all_history_images(force_refresh=False):
"""Get all history images sorted by modification time (newest first). Uses caching."""
global _image_cache, _cache_timestamp
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
# Check if we need to refresh cache
current_time = time.time()
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
return _image_cache
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
_image_cache = [x[0] for x in image_files]
_cache_timestamp = current_time
return _image_cache
def get_paginated_images(page=0, force_refresh=False):
"""Get images for a specific page."""
all_images = get_all_history_images(force_refresh)
total_images = len(all_images)
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
# Clamp page to valid range
page = max(0, min(page, total_pages - 1))
start_idx = page * IMAGES_PER_PAGE
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
page_images = all_images[start_idx:end_idx]
return page_images, page, total_pages, total_images
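# Example: with 130 saved images and IMAGES_PER_PAGE = 64, total_pages is
# (130 + 63) // 64 = 3; requesting page=2 returns the final slice, indices
# 128-129 of the newest-first list.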
def refresh_gallery(current_page=0):
"""Refresh gallery with current page."""
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def go_to_page(page_num, current_page):
"""Go to a specific page (1-indexed input)."""
try:
page = int(page_num) - 1 # Convert to 0-indexed
except (ValueError, TypeError):
page = current_page
images, page, total_pages, total_images = get_paginated_images(page)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def next_page(current_page):
"""Go to next page."""
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def prev_page(current_page):
"""Go to previous page."""
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def on_gallery_select(evt: gr.SelectData, current_page):
"""Handle image selection from gallery."""
if evt.index is None:
return "", "Select an image to view its settings"
# Resolve the clicked tile against the full newest-first image list
all_images = get_all_history_images()
total_images = len(all_images)
# Calculate the actual index in the full list
start_idx = current_page * IMAGES_PER_PAGE
actual_idx = start_idx + evt.index
if actual_idx >= total_images:
return "", "Image not found"
image_path = all_images[actual_idx]
metadata = read_image_metadata(image_path)
metadata_display = format_metadata_for_display(metadata)
return image_path, metadata_display
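# Example: on page 1 (0-indexed), clicking gallery tile 5 resolves to index
# 1 * IMAGES_PER_PAGE + 5 = 69 in the full newest-first list, so the metadata
# panel shows the file behind that exact tile.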
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
return [gr.update()] * 9 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
return [gr.update()] * 9 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
gr.update(value=metadata.get('image_prompt', '')),
gr.update(value=metadata.get('image_neg_prompt', '')),
gr.update(value=metadata.get('image_width', 1024)),
gr.update(value=metadata.get('image_height', 1024)),
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
gr.update(value=metadata.get('image_steps', 9)),
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
]
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
return updates + [status]
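# Note: the order of `updates` must match the output components wired up in
# create_event_handlers() (image_prompt through image_batch_count, then
# image_gallery_status); Gradio maps them positionally, so reordering either
# list would load values into the wrong widgets.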
def create_ui():
if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu']
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
# TAB 1: GENERATE
with gr.TabItem("Generate"):
with gr.Row():
with gr.Column(scale=4, min_width=350):
shared.gradio['image_prompt'] = gr.Textbox(
label="Prompt",
placeholder="Describe your imagination...",
lines=3,
autofocus=True,
value=shared.settings['image_prompt']
)
shared.gradio['image_neg_prompt'] = gr.Textbox(
label="Negative Prompt",
placeholder="Low quality...",
lines=3,
value=shared.settings['image_neg_prompt']
)
shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
gr.Markdown("### Dimensions")
with gr.Row():
with gr.Column():
shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
with gr.Column():
shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
with gr.Row():
shared.gradio['image_aspect_ratio'] = gr.Radio(
choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
value=shared.settings['image_aspect_ratio'],
label="Aspect Ratio",
interactive=True
)
gr.Markdown("### Config")
with gr.Row():
with gr.Column():
shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
# TAB 2: GALLERY (with pagination)
with gr.TabItem("Gallery"):
with gr.Row():
with gr.Column(scale=3):
# Pagination controls
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
# State for current page and selected image path
shared.gradio['image_current_page'] = gr.State(value=0)
shared.gradio['image_selected_path'] = gr.State(value="")
# Paginated gallery using gr.Gallery
shared.gradio['image_history_gallery'] = gr.Gallery(
value=lambda: get_paginated_images(0)[0],
label="Image History",
show_label=False,
columns=6,
object_fit="cover",
height="auto",
allow_preview=True,
elem_id="image-history-gallery"
)
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL
with gr.TabItem("Model"):
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['image_model_menu'] = gr.Dropdown(
choices=utils.get_available_image_models(),
value=shared.settings['image_model_menu'],
label='Model',
elem_classes='slim-dropdown'
)
shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
gr.Markdown("## Settings")
with gr.Row():
with gr.Column():
shared.gradio['image_quant'] = gr.Dropdown(
label='Quantization',
choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
value=shared.settings['image_quant'],
info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
)
shared.gradio['image_dtype'] = gr.Dropdown(
choices=['bfloat16', 'float16'],
value=shared.settings['image_dtype'],
label='Data Type',
info='bfloat16 recommended for modern GPUs'
)
shared.gradio['image_attn_backend'] = gr.Dropdown(
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
value=shared.settings['image_attn_backend'],
label='Attention Backend',
info='SDPA is default. Flash Attention requires compatible GPU.'
)
with gr.Column():
shared.gradio['image_compile'] = gr.Checkbox(
value=shared.settings['image_compile'],
label='Compile Model',
info='Faster inference after first run. First run will be slow.'
)
shared.gradio['image_cpu_offload'] = gr.Checkbox(
value=shared.settings['image_cpu_offload'],
label='CPU Offload',
info='Enable for low VRAM GPUs. Slower but uses less memory.'
)
with gr.Column():
shared.gradio['image_download_path'] = gr.Textbox(
label="Download model",
placeholder="Tongyi-MAI/Z-Image-Turbo",
info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
)
shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
shared.gradio['image_model_status'] = gr.Markdown(
value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
)
def create_event_handlers():
# Dimension controls
shared.gradio['image_aspect_ratio'].change(
apply_aspect_ratio,
gradio('image_aspect_ratio', 'image_width', 'image_height'),
gradio('image_width', 'image_height'),
show_progress=False
)
shared.gradio['image_width'].release(
update_height_from_width,
gradio('image_width', 'image_aspect_ratio'),
gradio('image_height'),
show_progress=False
)
shared.gradio['image_height'].release(
update_width_from_height,
gradio('image_height', 'image_aspect_ratio'),
gradio('image_width'),
show_progress=False
)
shared.gradio['image_swap_btn'].click(
swap_dimensions_and_update_ratio,
gradio('image_width', 'image_height', 'image_aspect_ratio'),
gradio('image_width', 'image_height', 'image_aspect_ratio'),
show_progress=False
)
# Generation
shared.gradio['image_generate_btn'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_neg_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
generate, gradio('interface_state'), gradio('image_output_gallery'))
# Model management
shared.gradio['image_refresh_models'].click(
lambda: gr.update(choices=utils.get_available_image_models()),
None,
gradio('image_model_menu'),
show_progress=False
)
shared.gradio['image_load_model'].click(
load_image_model_wrapper,
gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
gradio('image_model_status'),
show_progress=True
)
shared.gradio['image_unload_model'].click(
unload_image_model_wrapper,
None,
gradio('image_model_status'),
show_progress=False
)
shared.gradio['image_download_btn'].click(
download_image_model_wrapper,
gradio('image_download_path'),
gradio('image_model_status', 'image_model_menu'),
show_progress=True
)
# Gallery pagination handlers
shared.gradio['image_refresh_history'].click(
refresh_gallery,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_next_page'].click(
next_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_prev_page'].click(
prev_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_go_to_page'].click(
go_to_page,
gradio('image_page_input', 'image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
# Image selection from gallery
shared.gradio['image_history_gallery'].select(
on_gallery_select,
gradio('image_current_page'),
gradio('image_selected_path', 'image_settings_display'),
show_progress=False
)
# Send to Generate
shared.gradio['image_send_to_generate'].click(
send_to_generate,
gradio('image_selected_path'),
gradio(
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_gallery_status'
),
show_progress=False
)
def generate(state):
"""
Generate images using the loaded model.
Automatically adjusts parameters based on pipeline type.
"""
import torch
import numpy as np
model_name = state['image_model_menu']
if not model_name or model_name == 'None':
logger.error("No image model selected. Go to the Model tab and select a model.")
return []
if shared.image_model is None:
result = load_image_model(
model_name,
dtype=state['image_dtype'],
attn_backend=state['image_attn_backend'],
cpu_offload=state['image_cpu_offload'],
compile_model=state['image_compile'],
quant_method=state['image_quant']
2025-11-28 00:32:01 +01:00
)
if result is None:
logger.error(f"Failed to load model `{model_name}`.")
return []
shared.image_model_name = model_name
seed = state['image_seed']
if seed == -1:
seed = np.random.randint(0, 2**32 - 1, dtype=np.int64)  # explicit dtype avoids int32 overflow on Windows NumPy builds
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
# Get pipeline type for parameter adjustment
pipeline_type = getattr(shared, 'image_pipeline_type', None)
if pipeline_type is None:
pipeline_type = get_pipeline_type(shared.image_model)
# Build generation kwargs based on pipeline type
gen_kwargs = {
"prompt": state['image_prompt'],
"negative_prompt": state['image_neg_prompt'],
"height": int(state['image_height']),
"width": int(state['image_width']),
"num_inference_steps": int(state['image_steps']),
"num_images_per_prompt": int(state['image_batch_size']),
"generator": generator,
}
# Add pipeline-specific parameters
if pipeline_type == 'qwenimage':
# Qwen-Image uses true_cfg_scale instead of guidance_scale
gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
else:
# Z-Image and others use guidance_scale
gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
t0 = time.time()
for i in range(int(state['image_batch_count'])):
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
t1 = time.time()
save_generated_images(all_images, state, seed)
logger.info(f'Images generated in {(t1-t0):.2f} seconds ({state["image_steps"]/(t1-t0):.2f} steps/s, seed {seed})')
return all_images
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
if not model_name or model_name == 'None':
yield "No model selected"
return
try:
yield f"Loading `{model_name}`..."
unload_image_model()
result = load_image_model(
model_name,
dtype=dtype,
attn_backend=attn_backend,
cpu_offload=cpu_offload,
compile_model=compile_model,
quant_method=quant_method
)
if result is not None:
shared.image_model_name = model_name
yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
else:
yield f"✗ Failed to load `{model_name}`"
except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```"
def unload_image_model_wrapper():
unload_image_model()
if shared.image_model_name != 'None':
return f"Model: **{shared.image_model_name}** (not loaded)"
return "No model loaded"
def download_image_model_wrapper(model_path):
from huggingface_hub import snapshot_download
if not model_path:
yield "No model specified", gr.update()
return
try:
model_path = model_path.strip()
if model_path.startswith('https://huggingface.co/'):
model_path = model_path[len('https://huggingface.co/'):]
elif model_path.startswith('huggingface.co/'):
model_path = model_path[len('huggingface.co/'):]
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
model_id, branch = model_path, 'main'
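# Example: "Tongyi-MAI/Z-Image-Turbo:main" splits into model_id
# "Tongyi-MAI/Z-Image-Turbo" and branch "main"; the snapshot is stored in
# <image_model_dir>/Tongyi-MAI_Z-Image-Turbo.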
folder_name = model_id.replace('/', '_')
output_folder = Path(shared.args.image_model_dir) / folder_name
yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
snapshot_download(
repo_id=model_id,
revision=branch,
local_dir=output_folder,
local_dir_use_symlinks=False,
)
new_choices = utils.get_available_image_models()
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()