2025-11-27 22:44:07 +01:00
import os
2025-11-28 00:32:01 +01:00
import traceback
2025-11-27 23:24:35 +01:00
from datetime import datetime
2025-11-28 00:32:01 +01:00
from pathlib import Path
2025-11-27 23:24:35 +01:00
import gradio as gr
import numpy as np
import torch
2025-12-01 19:42:03 +01:00
from modules import shared , ui , utils
2025-11-27 23:24:35 +01:00
from modules . image_models import load_image_model , unload_image_model
2025-12-01 19:42:03 +01:00
from modules . utils import gradio
2025-11-27 19:10:11 +01:00
2025-11-27 23:38:50 +01:00
ASPECT_RATIOS = {
" 1:1 Square " : ( 1 , 1 ) ,
" 16:9 Cinema " : ( 16 , 9 ) ,
" 9:16 Mobile " : ( 9 , 16 ) ,
" 4:3 Photo " : ( 4 , 3 ) ,
" Custom " : None ,
}
2025-12-01 19:42:03 +01:00
# Granularity used when rounding derived width/height values (same as the width/height slider step).
STEP = 32
2025-11-27 23:38:50 +01:00
2025-11-28 00:32:01 +01:00
def round_to_step(value, step=STEP):
    """Return the multiple of *step* nearest to *value*.

    Uses Python's round(), so an exact half-way quotient resolves to an even
    number of steps (e.g. 80 with step 32 -> 64).
    """
    multiples = round(value / step)
    return multiples * step
def clamp(value, min_val, max_val):
    """Limit *value* to the closed range [min_val, max_val].

    The upper bound is applied first and the lower bound last, so if the bounds
    are inverted the lower bound wins.
    """
    capped = min(max_val, value)
    return max(min_val, capped)
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
    """Resize width/height to match the selected aspect-ratio preset.

    The shorter side of the preset is kept as the anchor (width for portrait,
    height for landscape, the smaller of the two for square) and the other side
    is derived from it, rounded to STEP. Results are limited to 256..2048.

    If the derived (longer) side would exceed the 2048 maximum, it is capped at
    2048 and the anchor side is re-derived from it. Simply clamping the long side
    would silently change the ratio (e.g. 16:9 at height 1536 would become
    2048x1536, i.e. 4:3).

    Args:
        aspect_ratio: a key of ASPECT_RATIOS; "Custom" or unknown keys leave the
            dimensions untouched.
        current_width: current width slider value.
        current_height: current height slider value.

    Returns:
        (width, height) as ints.
    """
    min_side, max_side = 256, 2048

    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return current_width, current_height

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]

    if w_ratio == h_ratio:
        base = min(current_width, current_height)
        new_width = base
        new_height = base
    elif w_ratio < h_ratio:
        # Portrait: keep the width, derive the (longer) height.
        new_width = current_width
        new_height = round_to_step(current_width * h_ratio / w_ratio)
        if new_height > max_side:
            # Cap the long side and shrink the anchor so the ratio survives.
            new_height = max_side
            new_width = round_to_step(new_height * w_ratio / h_ratio)
    else:
        # Landscape: keep the height, derive the (longer) width.
        new_height = current_height
        new_width = round_to_step(current_height * w_ratio / h_ratio)
        if new_width > max_side:
            new_width = max_side
            new_height = round_to_step(new_width * h_ratio / w_ratio)

    new_width = clamp(new_width, min_side, max_side)
    new_height = clamp(new_height, min_side, max_side)

    return int(new_width), int(new_height)
def update_height_from_width(width, aspect_ratio):
    """Derive the height that matches *width* under the selected preset.

    Returns an int height (rounded to STEP and clamped to 256..2048), or an
    empty ``gr.update()`` (leave the height slider unchanged) when the preset is
    "Custom" or not a known ASPECT_RATIOS key.
    """
    if aspect_ratio != "Custom" and aspect_ratio in ASPECT_RATIOS:
        w_term, h_term = ASPECT_RATIOS[aspect_ratio]
        derived = round_to_step(width * h_term / w_term)
        return int(clamp(derived, 256, 2048))

    return gr.update()
def update_width_from_height(height, aspect_ratio):
    """Derive the width that matches *height* under the selected preset.

    Returns an int width (rounded to STEP and clamped to 256..2048), or an
    empty ``gr.update()`` (leave the width slider unchanged) when the preset is
    "Custom" or not a known ASPECT_RATIOS key.
    """
    if aspect_ratio != "Custom" and aspect_ratio in ASPECT_RATIOS:
        w_term, h_term = ASPECT_RATIOS[aspect_ratio]
        derived = round_to_step(height * w_term / h_term)
        return int(clamp(derived, 256, 2048))

    return gr.update()
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
    """Swap width and height and select the preset matching the swapped size.

    ``aspect_ratio`` (the currently selected preset) is part of the event
    signature but is not consulted. Presets are tried in ASPECT_RATIOS order and
    the first one whose ratio predicts a height within STEP of the swapped
    height is chosen; otherwise "Custom" is returned.

    Returns:
        (new_width, new_height, preset_name)
    """
    swapped_w, swapped_h = height, width

    def _fits(terms):
        w_term, h_term = terms
        return abs(swapped_w * h_term / w_term - swapped_h) < STEP

    preset = next(
        (
            name
            for name, terms in ASPECT_RATIOS.items()
            if terms is not None and _fits(terms)
        ),
        "Custom",
    )

    return swapped_w, swapped_h, preset
2025-11-27 19:10:11 +01:00
def create_ui():
    """Build the "Image AI" tab and register its components in ``shared.gradio``.

    Component initial values come from ``shared.settings``. Event wiring is done
    separately by create_event_handlers().
    """
    initial_model = shared.settings['image_model_menu']
    has_initial_model = initial_model != 'None'

    # Remember the configured model name so status messages (e.g. after an
    # unload) can show it before anything has actually been loaded.
    if has_initial_model:
        shared.image_model_name = initial_model

    with gr.Tab("Image AI", elem_id="image-ai-tab"):
        with gr.Tabs():
            # TAB 1: GENERATE
            with gr.TabItem("Generate"):
                with gr.Row():
                    with gr.Column(scale=4, min_width=350):
                        shared.gradio['image_prompt'] = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe your imagination...",
                            lines=3,
                            autofocus=True,
                            value=shared.settings['image_prompt']
                        )
                        shared.gradio['image_neg_prompt'] = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="Low quality...",
                            lines=3,
                            value=shared.settings['image_neg_prompt']
                        )

                        shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
                        gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")

                        gr.Markdown("### 📐 Dimensions")
                        with gr.Row():
                            with gr.Column():
                                # Slider step shares STEP with the rounding helpers so
                                # derived sizes always land on a valid slider position.
                                shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")

                            with gr.Column():
                                shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")

                            shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")

                        with gr.Row():
                            shared.gradio['image_aspect_ratio'] = gr.Radio(
                                # Derived from ASPECT_RATIOS so the choices can never
                                # drift from the presets the helpers understand.
                                choices=list(ASPECT_RATIOS),
                                value=shared.settings['image_aspect_ratio'],
                                label="Aspect Ratio",
                                interactive=True
                            )

                        gr.Markdown("### ⚙️ Config")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
                                shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")

                            with gr.Column():
                                shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                                shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")

                    with gr.Column(scale=6, min_width=500):
                        with gr.Column(elem_classes=["viewport-container"]):
                            shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True)

                        with gr.Row():
                            shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)

            # TAB 2: GALLERY
            with gr.TabItem("Gallery"):
                with gr.Row():
                    shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")

                shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True)

            # TAB 3: MODEL
            with gr.TabItem("Model"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            shared.gradio['image_model_menu'] = gr.Dropdown(
                                choices=utils.get_available_image_models(),
                                value=initial_model,
                                label='Model',
                                elem_classes='slim-dropdown'
                            )
                            shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
                            shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
                            shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')

                        gr.Markdown("## Settings")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_dtype'] = gr.Dropdown(
                                    choices=['bfloat16', 'float16'],
                                    value=shared.settings['image_dtype'],
                                    label='Data Type',
                                    info='bfloat16 recommended for modern GPUs'
                                )
                                shared.gradio['image_attn_backend'] = gr.Dropdown(
                                    choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
                                    value=shared.settings['image_attn_backend'],
                                    label='Attention Backend',
                                    info='SDPA is default. Flash Attention requires compatible GPU.'
                                )

                            with gr.Column():
                                shared.gradio['image_compile'] = gr.Checkbox(
                                    value=shared.settings['image_compile'],
                                    label='Compile Model',
                                    info='Faster inference after first run. First run will be slow.'
                                )
                                shared.gradio['image_cpu_offload'] = gr.Checkbox(
                                    value=shared.settings['image_cpu_offload'],
                                    label='CPU Offload',
                                    info='Enable for low VRAM GPUs. Slower but uses less memory.'
                                )

                    with gr.Column():
                        shared.gradio['image_download_path'] = gr.Textbox(
                            label="Download model",
                            placeholder="Tongyi-MAI/Z-Image-Turbo",
                            info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
                        )
                        shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
                        shared.gradio['image_model_status'] = gr.Markdown(
                            value=f"Model: **{initial_model}** (not loaded)" if has_initial_model else "No model selected"
                        )
2025-11-27 23:25:49 +01:00
2025-11-28 00:32:01 +01:00
2025-12-01 19:42:03 +01:00
def create_event_handlers():
    """Connect the Image AI tab's components (created by create_ui) to their callbacks."""
    comp = shared.gradio

    # --- Dimension controls: keep width/height in line with the chosen preset ---
    comp['image_aspect_ratio'].change(
        apply_aspect_ratio,
        gradio('image_aspect_ratio', 'image_width', 'image_height'),
        gradio('image_width', 'image_height'),
        show_progress=False
    )
    comp['image_width'].release(
        update_height_from_width,
        gradio('image_width', 'image_aspect_ratio'),
        gradio('image_height'),
        show_progress=False
    )
    comp['image_height'].release(
        update_width_from_height,
        gradio('image_height', 'image_aspect_ratio'),
        gradio('image_width'),
        show_progress=False
    )
    comp['image_swap_btn'].click(
        swap_dimensions_and_update_ratio,
        gradio('image_width', 'image_height', 'image_aspect_ratio'),
        gradio('image_width', 'image_height', 'image_aspect_ratio'),
        show_progress=False
    )

    # --- Generation: the button and Enter in either prompt box share one chain:
    # snapshot the interface into interface_state, then run generate() on it ---
    generation_triggers = (
        comp['image_generate_btn'].click,
        comp['image_prompt'].submit,
        comp['image_neg_prompt'].submit,
    )
    for trigger in generation_triggers:
        trigger(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')
        ).then(
            generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed')
        )

    # --- Model management ---
    comp['image_refresh_models'].click(
        lambda: gr.update(choices=utils.get_available_image_models()),
        None,
        gradio('image_model_menu'),
        show_progress=False
    )
    comp['image_load_model'].click(
        load_image_model_wrapper,
        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
        gradio('image_model_status'),
        show_progress=True
    )
    comp['image_unload_model'].click(
        unload_image_model_wrapper,
        None,
        gradio('image_model_status'),
        show_progress=False
    )
    comp['image_download_btn'].click(
        download_image_model_wrapper,
        gradio('image_download_path'),
        gradio('image_model_status', 'image_model_menu'),
        show_progress=True
    )

    # --- History ---
    comp['image_refresh_history'].click(
        get_history_images,
        None,
        gradio('image_history_gallery'),
        show_progress=False
    )
def generate(state):
    """Generate images for the current interface state.

    If no image model is loaded yet, the model selected in the UI is loaded
    first with the UI's dtype/attention/offload/compile settings.

    Args:
        state: mapping of interface values (the 'interface_state' gathered by
            ui.gather_interface_values); the image_* keys below are read.

    Returns:
        (images, info_markdown): the list of generated images for the output
        gallery, and a status string (the seed used, or an error message).
    """
    model_name = state['image_model_menu']
    if not model_name or model_name == 'None':
        return [], "No image model selected. Go to the Model tab and select a model."

    # Lazily load the selected model on first use.
    if shared.image_model is None:
        result = load_image_model(
            model_name,
            dtype=state['image_dtype'],
            attn_backend=state['image_attn_backend'],
            cpu_offload=state['image_cpu_offload'],
            compile_model=state['image_compile']
        )
        if result is None:
            return [], f"Failed to load model `{model_name}`."

        shared.image_model_name = model_name

    seed = state['image_seed']
    # An empty Seed box can arrive as None; treat it like -1 (random) instead of crashing.
    if seed is None or seed == -1:
        # Explicit int64: NumPy's legacy default int is 32-bit on Windows (NumPy < 2),
        # where an upper bound of 2**32 - 1 raises "high is out of bounds".
        seed = np.random.randint(0, 2**32 - 1, dtype=np.int64)
    seed = int(seed)

    # Seeded generator on CUDA when available; fall back to CPU rather than
    # failing outright on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device)
    all_images = []

    for i in range(int(state['image_batch_count'])):
        # Each sequential run uses its own deterministic seed: seed, seed + 1, ...
        generator.manual_seed(seed + i)
        batch_results = shared.image_model(
            prompt=state['image_prompt'],
            negative_prompt=state['image_neg_prompt'],
            height=int(state['image_height']),
            width=int(state['image_width']),
            num_inference_steps=int(state['image_steps']),
            guidance_scale=0.0,
            num_images_per_prompt=int(state['image_batch_size']),
            generator=generator,
        ).images
        all_images.extend(batch_results)

    save_generated_images(all_images, state['image_prompt'], seed)
    return all_images, f"Seed: {seed}"
2025-11-27 22:53:46 +01:00
2025-11-28 00:32:01 +01:00
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
    """Load the chosen image model for the Load button, streaming status text.

    Generator: yields Markdown status strings for the model status box. Any
    currently loaded image model is unloaded before the new one is loaded.
    On success, ``shared.image_model_name`` is updated. Exceptions are caught
    and reported as a fenced traceback.
    """
    has_selection = bool(model_name) and model_name != 'None'
    if not has_selection:
        yield "No model selected"
        return

    try:
        yield f"Loading `{model_name}`..."
        unload_image_model()

        loaded = load_image_model(
            model_name,
            dtype=dtype,
            attn_backend=attn_backend,
            cpu_offload=cpu_offload,
            compile_model=compile_model
        )

        if loaded is None:
            yield f"✗ Failed to load `{model_name}`"
        else:
            shared.image_model_name = model_name
            yield f"✓ Loaded **{model_name}**"
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```"
2025-11-28 00:32:01 +01:00
def unload_image_model_wrapper():
    """Unload the current image model and return a status line for the UI.

    Returns:
        "Model: **<name>** (not loaded)" when a model name is remembered in
        ``shared.image_model_name``, otherwise "No model loaded".
    """
    unload_image_model()
    name = shared.image_model_name
    # Same "no model" test as generate() and load_image_model_wrapper(): a falsy
    # value or the 'None' sentinel. Previously None/'' rendered "Model: **None**".
    if name and name != 'None':
        return f"Model: **{name}** (not loaded)"

    return "No model loaded"
2025-11-28 00:32:01 +01:00
def download_image_model_wrapper(model_path):
    """Download a model snapshot from the Hugging Face Hub for the Download button.

    Generator: yields (status_markdown, model_menu_update) pairs. *model_path* is
    "owner/name" or "owner/name:branch"; the branch defaults to "main". Files
    are written to ``<shared.args.image_model_dir>/owner_name`` and, on success,
    the model dropdown is refreshed and set to the new folder.
    """
    # Tolerate None and stray whitespace pasted around the repo id.
    model_path = (model_path or '').strip()
    if not model_path:
        yield "No model specified", gr.update()
        return

    try:
        # Imported lazily and inside the try, so a missing or broken
        # huggingface_hub install is reported in the status box instead of
        # escaping the event handler.
        from huggingface_hub import snapshot_download

        if ':' in model_path:
            model_id, branch = model_path.rsplit(':', 1)
        else:
            model_id, branch = model_path, 'main'

        model_id = model_id.strip()
        # "owner/name:" (empty branch) falls back to the default branch
        # instead of passing an empty revision.
        branch = branch.strip() or 'main'

        folder_name = model_id.replace('/', '_')
        output_folder = Path(shared.args.image_model_dir) / folder_name

        yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()

        snapshot_download(
            repo_id=model_id,
            revision=branch,
            local_dir=output_folder,
            local_dir_use_symlinks=False,
        )

        new_choices = utils.get_available_image_models()
        yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
2025-11-28 00:32:01 +01:00
2025-11-27 22:53:46 +01:00
def save_generated_images(images, prompt, seed):
    """Save *images* as PNG files under user_data/image_outputs/<YYYY-MM-DD>/.

    Each file is named "<HH-MM-SS>_<seed>_<index>.png", where index is the
    image's position in *images*. *prompt* is accepted but currently unused.
    The path is relative to the current working directory.
    """
    # One timestamp for the whole call: the date folder and the time prefix can
    # no longer disagree across midnight, and every image from the same call
    # shares the same prefix.
    now = datetime.now()
    folder_path = os.path.join("user_data", "image_outputs", now.strftime("%Y-%m-%d"))
    os.makedirs(folder_path, exist_ok=True)

    timestamp = now.strftime("%H-%M-%S")
    for idx, img in enumerate(images):
        filename = f"{timestamp}_{seed}_{idx}.png"
        img.save(os.path.join(folder_path, filename))
2025-11-27 22:53:46 +01:00
def get_history_images():
    """Return paths of saved images under user_data/image_outputs, newest first.

    The directory is searched recursively; .png/.jpg/.jpeg files are matched
    case-insensitively and ordered by modification time (most recent first).
    Files that disappear while the directory is being scanned are skipped.
    Returns an empty list when the output directory does not exist.
    """
    output_dir = os.path.join("user_data", "image_outputs")
    if not os.path.exists(output_dir):
        return []

    image_files = []
    for root, _, files in os.walk(output_dir):
        for file in files:
            if not file.lower().endswith((".png", ".jpg", ".jpeg")):
                continue

            full_path = os.path.join(root, file)
            try:
                mtime = os.path.getmtime(full_path)
            except OSError:
                # Removed or moved between os.walk() listing it and the stat call;
                # skip it rather than failing the whole gallery refresh.
                continue
            image_files.append((full_path, mtime))

    image_files.sort(key=lambda item: item[1], reverse=True)
    return [path for path, _ in image_files]