2025-12-02 00:41:58 +01:00
import json
2025-11-27 22:44:07 +01:00
import os
2025-12-01 22:59:10 +01:00
import time
2025-11-28 00:32:01 +01:00
import traceback
2025-11-27 23:24:35 +01:00
from datetime import datetime
2025-11-28 00:32:01 +01:00
from pathlib import Path
2025-11-27 23:24:35 +01:00
import gradio as gr
import numpy as np
import torch
2025-12-02 00:41:58 +01:00
from PIL import Image
from PIL . PngImagePlugin import PngInfo
2025-11-27 23:24:35 +01:00
2025-12-01 19:42:03 +01:00
from modules import shared , ui , utils
2025-11-27 23:24:35 +01:00
from modules . image_models import load_image_model , unload_image_model
2025-12-01 22:59:10 +01:00
from modules . logging_colors import logger
2025-12-01 19:42:03 +01:00
from modules . utils import gradio
2025-11-27 19:10:11 +01:00
2025-11-27 23:38:50 +01:00
# Aspect-ratio presets offered in the Generate tab, as (width, height) proportions.
# "Custom" maps to None: the dimension handlers leave width/height untouched for it.
ASPECT_RATIOS = {
    "1:1 Square": (1, 1),
    "16:9 Cinema": (16, 9),
    "9:16 Mobile": (9, 16),
    "4:3 Photo": (4, 3),
    "Custom": None,
}

# Pixel granularity that derived width/height values are rounded to (round_to_step),
# also used as the matching tolerance when detecting a ratio after a swap.
STEP = 32

# Number of thumbnails shown per page in the Gallery tab.
IMAGES_PER_PAGE = 64

# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
    'image_prompt',
    'image_neg_prompt',
    'image_width',
    'image_height',
    'image_aspect_ratio',
    'image_steps',
    'image_seed',
    'image_batch_size',
    'image_batch_count',
]

# Cache for all image paths: newest-first list of files found under
# user_data/image_outputs, plus the time.time() at which it was built
# (both maintained by get_all_history_images).
_image_cache = []
_cache_timestamp = 0
2025-11-27 23:38:50 +01:00
2025-11-28 00:32:01 +01:00
def round_to_step(value, step=STEP):
    """Snap ``value`` to the nearest multiple of ``step`` (Python ``round`` semantics)."""
    step_count = round(value / step)
    return step_count * step
def clamp(value, min_val, max_val):
    """Bound ``value`` to the closed interval [min_val, max_val]."""
    capped_above = min(max_val, value)
    return max(min_val, capped_above)
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
    """Resize (width, height) so they match a preset aspect ratio.

    Square presets use the smaller current side for both dimensions. Portrait
    presets keep the width and derive the height; landscape presets keep the
    height and derive the width. If the derived side would exceed the 2048
    slider maximum, it is capped there and the kept side is scaled down
    instead, so the preset ratio is preserved. Previously both sides were
    clamped independently, which collapsed e.g. 16:9 at height 2048 into a
    2048x2048 square.

    Args:
        aspect_ratio: A key of ASPECT_RATIOS (e.g. "16:9 Cinema").
        current_width: Current width slider value.
        current_height: Current height slider value.

    Returns:
        tuple[int, int]: (width, height), each clamped to [256, 2048]. For
        "Custom" or an unknown preset the inputs are returned unchanged.
    """
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return current_width, current_height

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    min_side, max_side = 256, 2048

    if w_ratio == h_ratio:
        new_width = new_height = min(current_width, current_height)
    elif w_ratio < h_ratio:
        # Portrait: width drives height.
        new_width = current_width
        new_height = round_to_step(current_width * h_ratio / w_ratio)
        if new_height > max_side:
            # Cap the long side and shrink the short side to keep the ratio.
            new_height = max_side
            new_width = round_to_step(max_side * w_ratio / h_ratio)
    else:
        # Landscape: height drives width.
        new_height = current_height
        new_width = round_to_step(current_height * w_ratio / h_ratio)
        if new_width > max_side:
            new_width = max_side
            new_height = round_to_step(max_side * h_ratio / w_ratio)

    new_width = clamp(new_width, min_side, max_side)
    new_height = clamp(new_height, min_side, max_side)

    return int(new_width), int(new_height)
def update_height_from_width(width, aspect_ratio):
    """Derive the height slider value from ``width`` under the selected preset.

    Returns a no-op ``gr.update()`` for "Custom" (whose ratio entry is None)
    or an unknown preset; otherwise the step-rounded height clamped to
    [256, 2048] as an int.
    """
    ratios = ASPECT_RATIOS.get(aspect_ratio)
    if ratios is None:
        return gr.update()

    w_ratio, h_ratio = ratios
    derived = round_to_step(width * h_ratio / w_ratio)
    return int(clamp(derived, 256, 2048))
def update_width_from_height(height, aspect_ratio):
    """Derive the width slider value from ``height`` under the selected preset.

    Returns a no-op ``gr.update()`` for "Custom" (whose ratio entry is None)
    or an unknown preset; otherwise the step-rounded width clamped to
    [256, 2048] as an int.
    """
    ratios = ASPECT_RATIOS.get(aspect_ratio)
    if ratios is None:
        return gr.update()

    w_ratio, h_ratio = ratios
    derived = round_to_step(height * w_ratio / h_ratio)
    return int(clamp(derived, 256, 2048))
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
    """Swap width and height, then pick the preset that matches the new shape.

    The first preset (in ASPECT_RATIOS order) whose expected height for the
    new width lies within STEP pixels of the new height is selected; if none
    matches, "Custom" is returned. ``aspect_ratio`` is accepted for the
    event signature but not consulted.

    Returns:
        tuple: (new_width, new_height, preset_name)
    """
    swapped_width, swapped_height = height, width

    matching_preset = next(
        (
            name
            for name, ratios in ASPECT_RATIOS.items()
            if ratios is not None
            and abs(swapped_width * ratios[1] / ratios[0] - swapped_height) < STEP
        ),
        "Custom",
    )

    return swapped_width, swapped_height, matching_preset
2025-12-02 00:41:58 +01:00
def build_generation_metadata(state, actual_seed):
    """Build metadata dict from generation settings.

    Copies every METADATA_SETTINGS_KEYS entry present in ``state``, then
    records the seed that was really used (never the -1 "random" sentinel),
    the current local time in ISO format, and ``shared.image_model_name``.
    """
    metadata = {key: state[key] for key in METADATA_SETTINGS_KEYS if key in state}
    metadata.update(
        image_seed=actual_seed,
        generated_at=datetime.now().isoformat(),
        model=shared.image_model_name,
    )
    return metadata
def save_generated_images(images, state, actual_seed):
    """Save images with generation metadata embedded in PNG.

    Files go to ``user_data/image_outputs/<YYYY-MM-DD>/`` and are named
    ``<HH-MM-SS>_<seed:010d>_<idx:03d>.png``. The JSON produced by
    build_generation_metadata() is stored in the ``image_gen_settings`` PNG
    text chunk.

    Two runs with the same seed finishing within the same second used to
    produce identical filenames and silently overwrite earlier images. A
    numeric ``_N`` suffix is now appended instead of clobbering an existing
    file.

    Args:
        images: PIL images to save.
        state: Interface state dict (source of the saved settings).
        actual_seed: Seed actually used for these images; coerced to int for
            the filename.
    """
    now = datetime.now()
    folder_path = os.path.join("user_data", "image_outputs", now.strftime("%Y-%m-%d"))
    os.makedirs(folder_path, exist_ok=True)

    metadata = build_generation_metadata(state, actual_seed)
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    # The metadata is identical for every image in this call, so build the chunk once.
    png_info = PngInfo()
    png_info.add_text("image_gen_settings", metadata_json)

    # One timestamp per call keeps all images of a batch grouped under the same prefix.
    timestamp = now.strftime("%H-%M-%S")
    seed_int = int(actual_seed)

    for idx, img in enumerate(images):
        base_name = f"{timestamp}_{seed_int:010d}_{idx:03d}"
        filepath = os.path.join(folder_path, f"{base_name}.png")
        suffix = 1
        while os.path.exists(filepath):
            filepath = os.path.join(folder_path, f"{base_name}_{suffix}.png")
            suffix += 1

        img.save(filepath, pnginfo=png_info)
def read_image_metadata(image_path):
    """Read generation metadata from PNG file.

    Returns the JSON-decoded ``image_gen_settings`` text chunk, or None when
    the image has no text chunks, lacks that key, or cannot be opened or
    parsed. Failures are logged at debug level only.
    """
    try:
        with Image.open(image_path) as img:
            if not hasattr(img, 'text'):
                return None
            chunks = img.text
            if 'image_gen_settings' not in chunks:
                return None
            return json.loads(chunks['image_gen_settings'])
    except Exception as e:
        logger.debug(f"Could not read metadata from {image_path}: {e}")
    return None
def format_metadata_for_display(metadata):
    """Render generation metadata as Markdown for the Gallery side panel.

    Known keys are listed in a fixed order as ``**Label:** value`` entries,
    separated by blank lines; unknown keys are ignored. Non-empty prompt
    values longer than 200 characters are truncated with "...".
    """
    if not metadata:
        return "No generation settings found in this image."

    labelled_fields = (
        ('image_prompt', 'Prompt'),
        ('image_neg_prompt', 'Negative Prompt'),
        ('image_width', 'Width'),
        ('image_height', 'Height'),
        ('image_aspect_ratio', 'Aspect Ratio'),
        ('image_steps', 'Steps'),
        ('image_seed', 'Seed'),
        ('image_batch_size', 'Batch Size'),
        ('image_batch_count', 'Batch Count'),
        ('model', 'Model'),
        ('generated_at', 'Generated At'),
    )
    prompt_keys = {'image_prompt', 'image_neg_prompt'}
    max_prompt_chars = 200

    def render(key, label):
        value = metadata[key]
        if key in prompt_keys and value and len(str(value)) > max_prompt_chars:
            value = str(value)[:max_prompt_chars] + "..."
        return f"**{label}:** {value}"

    entries = [render(key, label) for key, label in labelled_fields if key in metadata]
    return "\n\n".join(["**Generation Settings**", ""] + entries)
def get_all_history_images(force_refresh=False):
    """Get all history images sorted by modification time (newest first). Uses caching.

    Scans ``user_data/image_outputs`` recursively for .png/.jpg/.jpeg files
    (extension matched case-insensitively). The result is cached in the
    module globals ``_image_cache``/``_cache_timestamp`` and reused for 2
    seconds unless ``force_refresh`` is set or the cache is empty.

    Files that disappear (or cannot be stat'ed) between directory listing and
    ``getmtime`` are skipped instead of aborting the scan with an OSError.

    Returns:
        list[str]: Image paths, newest first; [] if the output folder is missing.
    """
    global _image_cache, _cache_timestamp

    output_dir = os.path.join("user_data", "image_outputs")
    if not os.path.exists(output_dir):
        return []

    # Check if we need to refresh cache
    current_time = time.time()
    if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
        return _image_cache

    image_files = []
    for root, _, files in os.walk(output_dir):
        for file in files:
            if not file.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            full_path = os.path.join(root, file)
            try:
                mtime = os.path.getmtime(full_path)
            except OSError:
                # Deleted/renamed concurrently (e.g. by the user or another session).
                continue
            image_files.append((full_path, mtime))

    image_files.sort(key=lambda entry: entry[1], reverse=True)
    _image_cache = [path for path, _ in image_files]
    _cache_timestamp = current_time
    return _image_cache
def get_paginated_images(page=0, force_refresh=False):
    """Get images for a specific page.

    ``page`` is 0-indexed and clamped into the valid range; there is always
    at least one page, even when no images exist.

    Returns:
        tuple: (page_images, clamped_page, total_pages, total_images)
    """
    history = get_all_history_images(force_refresh)
    image_count = len(history)
    page_count = max(1, -(-image_count // IMAGES_PER_PAGE))  # ceiling division

    page = min(max(page, 0), page_count - 1)
    first = page * IMAGES_PER_PAGE

    return history[first:first + IMAGES_PER_PAGE], page, page_count, image_count
def _gallery_page(page, force_refresh=False):
    """Fetch one gallery page as the (images, page, page_info) triple the UI expects."""
    images, page, total_pages, total_images = get_paginated_images(page, force_refresh=force_refresh)
    info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
    return images, page, info


def refresh_gallery(current_page=0):
    """Rescan the output folder and re-render the current page."""
    return _gallery_page(current_page, force_refresh=True)


def go_to_page(page_num, current_page):
    """Go to a page given as 1-indexed user input; stay put if it is not a number."""
    try:
        target = int(page_num) - 1
    except (ValueError, TypeError):
        target = current_page
    return _gallery_page(target)


def next_page(current_page):
    """Advance one page (clamped at the last page)."""
    return _gallery_page(current_page + 1)


def prev_page(current_page):
    """Go back one page (clamped at the first page)."""
    return _gallery_page(current_page - 1)
def on_gallery_select(evt: gr.SelectData, current_page):
    """Handle image selection from gallery.

    Maps the clicked thumbnail (index within the current page) back to a file
    path and loads its embedded generation settings.

    Returns:
        tuple[str, str]: (selected image path or "", Markdown for the settings panel)
    """
    if evt.index is None:
        return "", "Select an image to view its settings"

    # Resolve the index against the cached listing the visible page was most
    # likely built from. Calling get_all_history_images() here would rescan once
    # the 2-second cache expires; any image saved after the page was rendered is
    # then inserted at the front of the newest-first list, shifting every index
    # and selecting the wrong file. Fall back to a scan only if nothing is cached.
    all_images = _image_cache or get_all_history_images()

    actual_idx = current_page * IMAGES_PER_PAGE + evt.index
    if not 0 <= actual_idx < len(all_images):
        return "", "Image not found"

    image_path = all_images[actual_idx]
    metadata = read_image_metadata(image_path)
    metadata_display = format_metadata_for_display(metadata)

    return image_path, metadata_display
def send_to_generate(selected_image_path):
    """Load settings from selected image and return updates for all Generate tab inputs.

    Returns nine component updates (prompt, negative prompt, width, height,
    aspect ratio, steps, seed, batch size, batch count) followed by a status
    string. If no usable image or metadata is found, the nine updates are
    no-ops and only the status changes.
    """
    # (metadata key, fallback) pairs, in the same order as the outputs of the handler.
    fields = (
        ('image_prompt', ''),
        ('image_neg_prompt', ''),
        ('image_width', 1024),
        ('image_height', 1024),
        ('image_aspect_ratio', '1:1 Square'),
        ('image_steps', 9),
        ('image_seed', -1),
        ('image_batch_size', 1),
        ('image_batch_count', 1),
    )

    if selected_image_path and os.path.exists(selected_image_path):
        metadata = read_image_metadata(selected_image_path)
        if metadata:
            updates = [gr.update(value=metadata.get(key, fallback)) for key, fallback in fields]
            status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
            return updates + [status]
        message = "No settings found in this image"
    else:
        message = "No image selected"

    return [gr.update()] * len(fields) + [message]
2025-11-27 19:10:11 +01:00
def create_ui():
    """Build the "Image AI" tab with its Generate, Gallery and Model sub-tabs.

    Every interactive component is registered in ``shared.gradio`` under an
    ``image_*`` key so that create_event_handlers() can wire callbacks to it.
    Initial component values are read from ``shared.settings``.
    """
    # Record the configured model name up front (nothing is loaded here);
    # build_generation_metadata() and unload_image_model_wrapper() read it.
    if shared.settings['image_model_menu'] != 'None':
        shared.image_model_name = shared.settings['image_model_menu']

    with gr.Tab("Image AI", elem_id="image-ai-tab"):
        with gr.Tabs():
            # TAB 1: GENERATE
            with gr.TabItem("Generate"):
                with gr.Row():
                    # Left column: prompts and generation parameters.
                    with gr.Column(scale=4, min_width=350):
                        shared.gradio['image_prompt'] = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe your imagination...",
                            lines=3,
                            autofocus=True,
                            value=shared.settings['image_prompt']
                        )
                        shared.gradio['image_neg_prompt'] = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="Low quality...",
                            lines=3,
                            value=shared.settings['image_neg_prompt']
                        )

                        shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
                        gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")

                        gr.Markdown("### Dimensions")
                        with gr.Row():
                            # Slider bounds (256-2048, step 32) match the clamp/STEP used by the aspect-ratio helpers.
                            with gr.Column():
                                shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
                            with gr.Column():
                                shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
                            shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")

                        with gr.Row():
                            # Choices mirror the keys of ASPECT_RATIOS.
                            shared.gradio['image_aspect_ratio'] = gr.Radio(
                                choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
                                value=shared.settings['image_aspect_ratio'],
                                label="Aspect Ratio",
                                interactive=True
                            )

                        gr.Markdown("### Config")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
                                # -1 is replaced by a random seed in generate().
                                shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
                            with gr.Column():
                                shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                                shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")

                    # Right column: output viewer.
                    with gr.Column(scale=6, min_width=500):
                        with gr.Column(elem_classes=["viewport-container"]):
                            shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")

            # TAB 2: GALLERY (with pagination)
            with gr.TabItem("Gallery"):
                with gr.Row():
                    with gr.Column(scale=3):
                        # Pagination controls
                        with gr.Row():
                            shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
                            shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
                            # NOTE(review): only the pagination handlers write this, so it shows
                            # "Loading..." until one of those buttons is used.
                            shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
                            shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
                            # 1-indexed page number typed by the user (converted in go_to_page).
                            shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
                            shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)

                        # State for current page (0-indexed) and selected image path
                        shared.gradio['image_current_page'] = gr.State(value=0)
                        shared.gradio['image_selected_path'] = gr.State(value="")

                        # Paginated gallery using gr.Gallery; the callable value is
                        # evaluated by Gradio when the page loads and shows the first page.
                        shared.gradio['image_history_gallery'] = gr.Gallery(
                            value=lambda: get_paginated_images(0)[0],
                            label="Image History",
                            show_label=False,
                            columns=6,
                            object_fit="cover",
                            height="auto",
                            allow_preview=True,
                            elem_id="image-history-gallery"
                        )

                    # Side panel: settings of the selected image.
                    with gr.Column(scale=1):
                        gr.Markdown("### Selected Image")
                        shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
                        shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
                        shared.gradio['image_gallery_status'] = gr.Markdown("")

            # TAB 3: MODEL
            with gr.TabItem("Model"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            shared.gradio['image_model_menu'] = gr.Dropdown(
                                choices=utils.get_available_image_models(),
                                value=shared.settings['image_model_menu'],
                                label='Model',
                                elem_classes='slim-dropdown'
                            )
                            shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
                            shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
                            shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')

                        gr.Markdown("## Settings")
                        with gr.Row():
                            # Load-time options, passed to load_image_model().
                            with gr.Column():
                                shared.gradio['image_quant'] = gr.Dropdown(
                                    label='Quantization',
                                    choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
                                    value=shared.settings['image_quant'],
                                    info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
                                )
                                shared.gradio['image_dtype'] = gr.Dropdown(
                                    choices=['bfloat16', 'float16'],
                                    value=shared.settings['image_dtype'],
                                    label='Data Type',
                                    info='bfloat16 recommended for modern GPUs'
                                )
                                shared.gradio['image_attn_backend'] = gr.Dropdown(
                                    choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
                                    value=shared.settings['image_attn_backend'],
                                    label='Attention Backend',
                                    info='SDPA is default. Flash Attention requires compatible GPU.'
                                )
                            with gr.Column():
                                shared.gradio['image_compile'] = gr.Checkbox(
                                    value=shared.settings['image_compile'],
                                    label='Compile Model',
                                    info='Faster inference after first run. First run will be slow.'
                                )
                                shared.gradio['image_cpu_offload'] = gr.Checkbox(
                                    value=shared.settings['image_cpu_offload'],
                                    label='CPU Offload',
                                    info='Enable for low VRAM GPUs. Slower but uses less memory.'
                                )

                    # Download column: fetches a HuggingFace repo into the image model folder.
                    with gr.Column():
                        shared.gradio['image_download_path'] = gr.Textbox(
                            label="Download model",
                            placeholder="Tongyi-MAI/Z-Image-Turbo",
                            info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
                        )
                        shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
                        # Status line shared by load/unload/download handlers.
                        shared.gradio['image_model_status'] = gr.Markdown(
                            value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
                        )
2025-11-27 23:25:49 +01:00
2025-11-28 00:32:01 +01:00
2025-12-01 19:42:03 +01:00
def create_event_handlers():
    """Attach callbacks to the Image AI components created by create_ui().

    Registration order matches the component groups: dimension controls,
    generation triggers, model management, gallery pagination, gallery
    selection and "Send to Generate".
    """
    components = shared.gradio

    # Dimension controls: (event, handler, input names, output names)
    dimension_bindings = (
        (components['image_aspect_ratio'].change, apply_aspect_ratio,
         ('image_aspect_ratio', 'image_width', 'image_height'), ('image_width', 'image_height')),
        (components['image_width'].release, update_height_from_width,
         ('image_width', 'image_aspect_ratio'), ('image_height',)),
        (components['image_height'].release, update_width_from_height,
         ('image_height', 'image_aspect_ratio'), ('image_width',)),
        (components['image_swap_btn'].click, swap_dimensions_and_update_ratio,
         ('image_width', 'image_height', 'image_aspect_ratio'), ('image_width', 'image_height', 'image_aspect_ratio')),
    )
    for event, handler, input_names, output_names in dimension_bindings:
        event(handler, gradio(*input_names), gradio(*output_names), show_progress=False)

    # Generation: snapshot the interface state first, then generate from it.
    generation_triggers = (
        components['image_generate_btn'].click,
        components['image_prompt'].submit,
        components['image_neg_prompt'].submit,
    )
    for trigger in generation_triggers:
        trigger(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
            generate, gradio('interface_state'), gradio('image_output_gallery'))

    # Model management
    components['image_refresh_models'].click(
        lambda: gr.update(choices=utils.get_available_image_models()),
        None,
        gradio('image_model_menu'),
        show_progress=False
    )
    components['image_load_model'].click(
        load_image_model_wrapper,
        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
        gradio('image_model_status'),
        show_progress=True
    )
    components['image_unload_model'].click(
        unload_image_model_wrapper,
        None,
        gradio('image_model_status'),
        show_progress=False
    )
    components['image_download_btn'].click(
        download_image_model_wrapper,
        gradio('image_download_path'),
        gradio('image_model_status', 'image_model_menu'),
        show_progress=True
    )

    # Gallery pagination handlers: each one re-renders gallery, page state and page label.
    pagination_bindings = (
        ('image_refresh_history', refresh_gallery, ('image_current_page',)),
        ('image_next_page', next_page, ('image_current_page',)),
        ('image_prev_page', prev_page, ('image_current_page',)),
        ('image_go_to_page', go_to_page, ('image_page_input', 'image_current_page')),
    )
    for button_name, handler, input_names in pagination_bindings:
        components[button_name].click(
            handler,
            gradio(*input_names),
            gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
            show_progress=False
        )

    # Image selection from gallery
    components['image_history_gallery'].select(
        on_gallery_select,
        gradio('image_current_page'),
        gradio('image_selected_path', 'image_settings_display'),
        show_progress=False
    )

    # Send to Generate: outputs must stay in the order send_to_generate returns them.
    generate_tab_outputs = (
        'image_prompt',
        'image_neg_prompt',
        'image_width',
        'image_height',
        'image_aspect_ratio',
        'image_steps',
        'image_seed',
        'image_batch_size',
        'image_batch_count',
        'image_gallery_status',
    )
    components['image_send_to_generate'].click(
        send_to_generate,
        gradio('image_selected_path'),
        gradio(*generate_tab_outputs),
        show_progress=False
    )
def generate(state):
    """Generate images for the current Generate-tab settings and save them.

    Loads the selected model on demand if none is loaded. Runs
    ``image_batch_count`` sequential pipeline calls, each producing
    ``image_batch_size`` images. Batch ``i`` is seeded with ``seed + i``.

    Fixes over the previous version:
      * Each batch is saved with the seed that actually produced it, so the
        PNG metadata (and "Send to Generate") can reproduce any batch, not
        just the first.
      * ``shared.image_model_name`` is updated only when a model is actually
        loaded here. It no longer records a dropdown selection that was never
        loaded.
      * A missing seed (cleared Number box) is treated like -1 (random).
      * The random seed is drawn as int64 and converted to a Python int. The
        platform-default integer can be 32-bit, which cannot hold the upper
        bound 2**32 - 1, and a Python int serializes to JSON.
      * steps/s now accounts for every sequential batch.

    Args:
        state: Interface state dict gathered by ``ui.gather_interface_values``.

    Returns:
        list: All generated PIL images, or [] if no model is selected or
        loading fails (the error is logged).
    """
    model_name = state['image_model_menu']
    if not model_name or model_name == 'None':
        logger.error("No image model selected. Go to the Model tab and select a model.")
        return []

    if shared.image_model is None:
        result = load_image_model(
            model_name,
            dtype=state['image_dtype'],
            attn_backend=state['image_attn_backend'],
            cpu_offload=state['image_cpu_offload'],
            compile_model=state['image_compile'],
            quant_method=state['image_quant']
        )
        if result is None:
            logger.error(f"Failed to load model `{model_name}`.")
            return []

        shared.image_model_name = model_name

    seed = state['image_seed']
    if seed is None or int(seed) == -1:
        seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    else:
        seed = int(seed)

    batch_count = int(state['image_batch_count'])
    steps = int(state['image_steps'])

    generator = torch.Generator("cuda").manual_seed(seed)
    batches = []  # (seed used, images) per sequential batch

    t0 = time.time()
    for i in range(batch_count):
        batch_seed = seed + i
        generator.manual_seed(batch_seed)
        batch_results = shared.image_model(
            prompt=state['image_prompt'],
            negative_prompt=state['image_neg_prompt'],
            height=int(state['image_height']),
            width=int(state['image_width']),
            num_inference_steps=steps,
            guidance_scale=0.0,
            num_images_per_prompt=int(state['image_batch_size']),
            generator=generator,
        ).images
        batches.append((batch_seed, batch_results))

    t1 = time.time()

    # Save outside the timed region so the speed report covers inference only.
    all_images = []
    for batch_seed, batch_results in batches:
        save_generated_images(batch_results, state, batch_seed)
        all_images.extend(batch_results)

    elapsed = t1 - t0
    total_steps = steps * batch_count
    steps_per_second = total_steps / elapsed if elapsed > 0 else float('inf')
    logger.info(f'Images generated in {elapsed:.2f} seconds ({steps_per_second:.2f} steps/s, seed {seed})')
    return all_images
2025-11-27 22:53:46 +01:00
2025-12-02 02:05:42 +01:00
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
    """Unload any current model and load ``model_name``, streaming status text.

    Yields Markdown status strings for the model status component: a
    progress message, then success/failure. Any exception is reported with
    its traceback instead of propagating to the UI.
    """
    has_model = bool(model_name) and model_name != 'None'
    if not has_model:
        yield "No model selected"
        return

    try:
        yield f"Loading `{model_name}`..."
        unload_image_model()

        loaded = load_image_model(
            model_name,
            dtype=dtype,
            attn_backend=attn_backend,
            cpu_offload=cpu_offload,
            compile_model=compile_model,
            quant_method=quant_method
        )

        if loaded is None:
            yield f"✗ Failed to load `{model_name}`"
            return

        shared.image_model_name = model_name
        yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```"
2025-11-28 00:32:01 +01:00
def unload_image_model_wrapper():
    """Unload the image model and return the status text for the Model tab."""
    unload_image_model()
    name = shared.image_model_name
    return "No model loaded" if name == 'None' else f"Model: **{name}** (not loaded)"
2025-11-28 00:32:01 +01:00
def download_image_model_wrapper(model_path):
    """Download a HuggingFace repo into the image model directory.

    ``model_path`` is ``user/model`` or ``user/model:branch`` (the branch
    defaults to ``main``). The repo is saved under
    ``<image_model_dir>/<user>_<model>``. Yields (status Markdown, model
    dropdown update) pairs. On success the dropdown is refreshed and the new
    folder selected.

    Robustness fixes: surrounding whitespace (common when pasting) is
    stripped from the repo id and branch, and an empty branch (``user/model:``)
    falls back to ``main``. Both cases previously produced an invalid
    repo_id/revision error. The HuggingFace import now happens only after the
    input has been validated.
    """
    model_path = (model_path or '').strip()
    if not model_path:
        yield "No model specified", gr.update()
        return

    from huggingface_hub import snapshot_download

    try:
        if ':' in model_path:
            model_id, branch = model_path.rsplit(':', 1)
        else:
            model_id, branch = model_path, ''

        model_id = model_id.strip()
        branch = branch.strip() or 'main'

        folder_name = model_id.replace('/', '_')
        output_folder = Path(shared.args.image_model_dir) / folder_name

        yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()

        snapshot_download(
            repo_id=model_id,
            revision=branch,
            local_dir=output_folder,
            local_dir_use_symlinks=False,
        )

        new_choices = utils.get_available_image_models()
        yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()