From 4ad2ad468e6bd29055dc6a66f006bab6f9d7f4c7 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 10:10:11 -0800
Subject: [PATCH 01/38] Add basic structure
---
modules/ui_image_generation.py | 9 +++++++++
server.py | 2 ++
2 files changed, 11 insertions(+)
create mode 100644 modules/ui_image_generation.py
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
new file mode 100644
index 00000000..389bc1e2
--- /dev/null
+++ b/modules/ui_image_generation.py
@@ -0,0 +1,9 @@
+import gradio as gr
+
+
+def create_ui():
+ pass
+
+
+def create_event_handlers():
+ pass
diff --git a/server.py b/server.py
index c804c342..87bbdc4a 100644
--- a/server.py
+++ b/server.py
@@ -50,6 +50,7 @@ from modules import (
ui_chat,
ui_default,
ui_file_saving,
+ ui_image_generation,
ui_model_menu,
ui_notebook,
ui_parameters,
@@ -163,6 +164,7 @@ def create_interface():
ui_chat.create_character_settings_ui() # Character tab
ui_model_menu.create_ui() # Model tab
if not shared.args.portable:
+ ui_image_generation.create_ui() # Image generation tab
training.create_ui() # Training tab
ui_session.create_ui() # Session tab
From 164c6fcdbf70266c4c7ce78b40a6c7c78577e53a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 13:44:07 -0800
Subject: [PATCH 02/38] Add the UI structure
---
css/main.css | 4 +-
modules/shared.py | 9 +-
modules/ui_image_generation.py | 88 ++++++++++++-
modules/ui_model_menu.py | 223 +++++++++++++++++++--------------
4 files changed, 227 insertions(+), 97 deletions(-)
diff --git a/css/main.css b/css/main.css
index fd79d24c..61a33a4b 100644
--- a/css/main.css
+++ b/css/main.css
@@ -93,11 +93,11 @@ ol li p, ul li p {
display: inline-block;
}
-#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
border: 0;
}
-#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
padding: 1rem;
}
diff --git a/modules/shared.py b/modules/shared.py
index 289fda54..2fbc205f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,7 +11,7 @@ import yaml
from modules.logging_colors import logger
from modules.presets import default_preset
-# Model variables
+# Text model variables
model = None
tokenizer = None
model_name = 'None'
@@ -20,6 +20,9 @@ is_multimodal = False
model_dirty_from_training = False
lora_names = []
+# Image model variables
+image_model = None
+
# Generation variables
stop_everything = False
generation_lock = None
@@ -46,6 +49,10 @@ group.add_argument('--extensions', type=str, nargs='+', help='The list of extens
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
+# Image generation
+group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
+group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
+
# Model loader
group = parser.add_argument_group('Model loader')
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 389bc1e2..5b5c624d 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,8 +1,94 @@
import gradio as gr
+import os
def create_ui():
- pass
+ with gr.Tab("Image AI", elem_id="image-ai-tab"):
+ with gr.Tabs():
+ # TAB 1: GENERATION STUDIO
+ with gr.TabItem("Generate Images"):
+ with gr.Row():
+
+ # === LEFT COLUMN: CONTROLS ===
+ with gr.Column(scale=4, min_width=350):
+
+ # 1. PROMPT
+ prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
+ neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
+
+ # 2. GENERATE BUTTON
+ generate_btn = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
+ gr.HTML("
")
+
+ # 3. DIMENSIONS
+ gr.Markdown("### 📐 Dimensions")
+ with gr.Row():
+ with gr.Column():
+ width_slider = gr.Slider(256, 2048, value=1024, step=32, label="Width")
+
+ with gr.Column():
+ height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height")
+
+ preset_radio = gr.Radio(
+ choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
+ value="1:1 Square",
+ label="Aspect Ratio",
+ interactive=True
+ )
+
+ # 4. SETTINGS & BATCHING
+ gr.Markdown("### ⚙️ Config")
+ with gr.Row():
+ with gr.Column():
+ steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps")
+ cfg_slider = gr.Slider(value=0.0, label="Guidance", interactive=False, info="Locked")
+ seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 = Random")
+
+ with gr.Column():
+ batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
+ batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
+
+ # === RIGHT COLUMN: VIEWPORT ===
+ with gr.Column(scale=6, min_width=500):
+ with gr.Column(elem_classes=["viewport-container"]):
+ output_gallery = gr.Gallery(
+ label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
+ )
+ with gr.Row():
+ used_seed = gr.Markdown(label="Info", interactive=False, lines=3)
+
+ # TAB 2: HISTORY VIEWER
+ with gr.TabItem("Gallery"):
+ with gr.Row():
+ refresh_btn = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
+
+ history_gallery = gr.Gallery(
+ label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True
+ )
+
+ # === WIRING ===
+
+ # Aspect Buttons
+ # btn_sq.click(lambda: set_dims(1024, 1024), outputs=[width_slider, height_slider])
+ # btn_port.click(lambda: set_dims(720, 1280), outputs=[width_slider, height_slider])
+ # btn_land.click(lambda: set_dims(1280, 720), outputs=[width_slider, height_slider])
+ # btn_wide.click(lambda: set_dims(1536, 640), outputs=[width_slider, height_slider])
+
+ # Generation
+ inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
+ outputs = [output_gallery, used_seed]
+
+ # generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+ # prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+ # neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+
+ # System
+ # load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
+
+ # History
+ # refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)
+ # Load history on app launch
+ # demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
def create_event_handlers():
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 86adc229..dbbd3274 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -27,112 +27,149 @@ def create_ui():
mu = shared.args.multi_user
with gr.Tab("Model", elem_id="model-tab"):
- with gr.Row():
- with gr.Column():
- with gr.Row():
- shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=lambda: shared.model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
- shared.gradio['load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
-
- shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys() if not shared.args.portable else ['llama.cpp'], value=None)
- with gr.Blocks():
- gr.Markdown("## Main options")
+ with gr.Tab("Text model"):
+ with gr.Row():
+ with gr.Column():
with gr.Row():
- with gr.Column():
- shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
- shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
- shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
- shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
- shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
- shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
+ shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=lambda: shared.model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
+ shared.gradio['load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
- with gr.Column():
- shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
- shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
- shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
- shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
- shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
- shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
- shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
- shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
- shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
- shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
-
- # Multimodal
- with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
- with gr.Row():
- shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
- ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
-
- # Speculative decoding
- with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
- with gr.Row():
- shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
- ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
-
- shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
- shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
- shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
- shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
-
- gr.Markdown("## Other options")
- with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
+ shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys() if not shared.args.portable else ['llama.cpp'], value=None)
+ with gr.Blocks():
+ gr.Markdown("## Main options")
with gr.Row():
with gr.Column():
- shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
- shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
- shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
- shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
- shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
- shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
- shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
- shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
- shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
- shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
- shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
- shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
- shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
+ shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
+ shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
+ shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
+ shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
+ shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
+ shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
with gr.Column():
- shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
- shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
- shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
- shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
- shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
- shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
- shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
- shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
- shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
- shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
- shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
- shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
- shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
- if not shared.args.portable:
+ shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
+ shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
+ shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
+ shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
+ shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
+ shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
+ shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
+ shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
+ shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
+ shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
+
+ # Multimodal
+ with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
with gr.Row():
- shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
- shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
+
+ # Speculative decoding
+ with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
+ with gr.Row():
+ shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
+
+ shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
+ shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
+ shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
+ shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
+
+ gr.Markdown("## Other options")
+ with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
+ shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
+ shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
+ shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
+ shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
+ shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
+ shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
+ shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
+ shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
+ shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
+ shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
+ shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
+ shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
+
+ with gr.Column():
+ shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
+ shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+ shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
+ shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
+ shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
+ shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
+ shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
+ shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+ shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
+ shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
+ shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
+ shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
+ shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
+ if not shared.args.portable:
+ with gr.Row():
+ shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
+ shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)
+
+ with gr.Column():
+ with gr.Tab("Download"):
+ shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
+ shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1, interactive=not mu)
+ with gr.Row():
+ shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
+ shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)
+
+ with gr.Tab("Customize instruction template"):
+ with gr.Row():
+ shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
+
+ shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
+ gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
- with gr.Column():
- with gr.Tab("Download"):
- shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
- shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1, interactive=not mu)
with gr.Row():
- shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
- shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)
+ shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
- with gr.Tab("Customize instruction template"):
+ with gr.Tab("Image model"):
+ with gr.Row():
+ with gr.Column():
with gr.Row():
- shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
- ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
+ shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
+ shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
- gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
+ with gr.Blocks():
+ gr.Markdown("## Main options")
+ with gr.Row():
+ with gr.Column():
+ pass
- with gr.Row():
- shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+ with gr.Column():
+ pass
+
+ gr.Markdown("## Other options")
+ with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
+ with gr.Row():
+ with gr.Column():
+ pass
+
+ with gr.Column():
+ pass
+
+ with gr.Column():
+ shared.gradio['image_custom_model_menu'] = gr.Textbox(label="Download model (diffusers format)", info="Enter the Hugging Face username/model path, for instance: Tongyi-MAI/Z-Image-Turbo. To specify a branch, add it at the end after a \":\" character like this: Tongyi-MAI/Z-Image-Turbo:main.", interactive=not mu)
+ with gr.Row():
+ shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
+
+ with gr.Row():
+ shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
def create_event_handlers():
From aa63c612dea93925d9ad7ffb8e3094434a41d1af Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 13:46:54 -0800
Subject: [PATCH 03/38] Progress on model loading
---
modules/shared.py | 1 +
modules/utils.py | 25 +++++++++++++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/modules/shared.py b/modules/shared.py
index 2fbc205f..fda4ece6 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -22,6 +22,7 @@ lora_names = []
# Image model variables
image_model = None
+image_model_name = 'None'
# Generation variables
stop_everything = False
diff --git a/modules/utils.py b/modules/utils.py
index e8d23a02..5315d0f8 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -153,6 +153,31 @@ def get_available_models():
return filtered_gguf_files + model_dirs
+def get_available_image_models():
+ model_dir = Path(shared.args.image_model_dir)
+
+ # Find directories with safetensors files
+ dirs_with_safetensors = set()
+ for item in os.listdir(model_dir):
+ item_path = model_dir / item
+ if item_path.is_dir():
+ if any(file.lower().endswith(('.safetensors', '.pt')) for file in os.listdir(item_path) if (item_path / file).is_file()):
+ dirs_with_safetensors.add(item)
+
+ # Find valid model directories
+ model_dirs = []
+ for item in os.listdir(model_dir):
+ item_path = model_dir / item
+ if not item_path.is_dir():
+ continue
+
+ model_dirs.append(item)
+
+ model_dirs = sorted(model_dirs, key=natural_keys)
+
+ return model_dirs
+
+
def get_available_ggufs():
model_list = []
model_dir = Path(shared.args.model_dir)
From 2f11b3040d5ad96c1643136589c57abafc1331d2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 13:53:46 -0800
Subject: [PATCH 04/38] Add functions
---
modules/ui_image_generation.py | 127 +++++++++++++++++++++++++++++++--
modules/utils.py | 4 +-
2 files changed, 124 insertions(+), 7 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 5b5c624d..b59a8458 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,6 +1,6 @@
import gradio as gr
import os
-
+from modules.utils import resolve_model_path
def create_ui():
with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -78,9 +78,9 @@ def create_ui():
inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
outputs = [output_gallery, used_seed]
- # generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
- # prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
- # neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+ generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+ prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+ neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
# System
# load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
@@ -91,5 +91,120 @@ def create_ui():
# demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
-def create_event_handlers():
- pass
+def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
+ if engine.pipe is None:
+ load_pipeline("SDPA", False, False, "bfloat16")
+
+ if seed == -1: seed = np.random.randint(0, 2**32 - 1)
+
+ # We use a base generator. For sequential batches, we might increment seed if desired,
+ # but here we keep the base seed logic consistent.
+ generator = torch.Generator("cuda").manual_seed(int(seed))
+
+ all_images = []
+
+ # SEQUENTIAL LOOP (Easy on VRAM)
+ for i in range(batch_count_seq):
+ # Update seed for subsequent batches so they aren't identical
+ current_seed = seed + i
+ generator.manual_seed(int(current_seed))
+
+ # PARALLEL GENERATION (Fast, Heavy VRAM)
+ # diffusers handles 'num_images_per_prompt' for parallel execution
+ batch_results = engine.pipe(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+ height=int(height),
+ width=int(width),
+ num_inference_steps=int(steps),
+ guidance_scale=0.0,
+ num_images_per_prompt=int(batch_size_parallel),
+ generator=generator,
+ ).images
+
+ all_images.extend(batch_results)
+
+ # Save to disk
+ save_generated_images(all_images, prompt, seed)
+
+ return all_images, seed
+
+
+# --- File Saving Logic ---
+def save_generated_images(images, prompt, seed):
+ # Create folder structure: outputs/YYYY-MM-DD/
+ date_str = datetime.now().strftime("%Y-%m-%d")
+ folder_path = os.path.join("outputs", date_str)
+ os.makedirs(folder_path, exist_ok=True)
+
+ saved_paths = []
+
+ for idx, img in enumerate(images):
+ timestamp = datetime.now().strftime("%H-%M-%S")
+ # Filename: Time_Seed_Index.png
+ filename = f"{timestamp}_{seed}_{idx}.png"
+ full_path = os.path.join(folder_path, filename)
+
+ # Save image
+ img.save(full_path)
+ saved_paths.append(full_path)
+
+ # Optional: Save prompt metadata in a text file next to it?
+ # For now, we just save the image.
+
+ return saved_paths
+
+
+# --- History Logic ---
+def get_history_images():
+ """Scans the outputs folder and returns all images, newest first"""
+ if not os.path.exists("outputs"):
+ return []
+
+ image_files = []
+ for root, dirs, files in os.walk("outputs"):
+ for file in files:
+ if file.endswith((".png", ".jpg", ".jpeg")):
+ full_path = os.path.join(root, file)
+ # Get creation time for sorting
+ mtime = os.path.getmtime(full_path)
+ image_files.append((full_path, mtime))
+
+ # Sort by time, newest first
+ image_files.sort(key=lambda x: x[1], reverse=True)
+ return [x[0] for x in image_files]
+
+
+def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
+ dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
+ target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
+
+ if engine.pipe is not None and engine.config["backend"] == attn_backend:
+ return gr.Info("Pipeline ready.")
+
+ try:
+ gr.Info(f"Loading Model ({attn_backend})...")
+ pipe = ZImagePipeline.from_pretrained(
+ engine.config["model_id"],
+ torch_dtype=target_dtype,
+ low_cpu_mem_usage=False,
+ )
+ if not offload_cpu: pipe.to("cuda")
+
+ if attn_backend == "Flash Attention 2":
+ pipe.transformer.set_attention_backend("flash")
+ elif attn_backend == "Flash Attention 3":
+ pipe.transformer.set_attention_backend("_flash_3")
+
+ if compile_model:
+ gr.Warning("Compiling... First run will be slow.")
+ pipe.transformer.compile()
+
+ if offload_cpu: pipe.enable_model_cpu_offload()
+
+ engine.pipe = pipe
+ engine.config["backend"] = attn_backend
+ return gr.Success("System Ready.")
+ except Exception as e:
+ return gr.Error(f"Init Failed: {str(e)}")
+
diff --git a/modules/utils.py b/modules/utils.py
index 5315d0f8..13a814ae 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -86,7 +86,7 @@ def check_model_loaded():
return True, None
-def resolve_model_path(model_name_or_path):
+def resolve_model_path(model_name_or_path, image_model=False):
"""
Resolves a model path, checking for a direct path
before the default models directory.
@@ -95,6 +95,8 @@ def resolve_model_path(model_name_or_path):
path_candidate = Path(model_name_or_path)
if path_candidate.exists():
return path_candidate
+ elif image_model:
+ return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
else:
return Path(f'{shared.args.model_dir}/{model_name_or_path}')
From a8736922349b82d83ecbe523f606fd08d1940e73 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 14:24:35 -0800
Subject: [PATCH 05/38] Image generation now functional
---
modules/shared.py | 4 ++
modules/ui_image_generation.py | 93 +++++++++++++---------------------
modules/ui_model_menu.py | 89 +++++++++++++++++++++++++++++++-
3 files changed, 127 insertions(+), 59 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index fda4ece6..66666b75 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -51,8 +51,12 @@ group.add_argument('--verbose', action='store_true', help='Print the prompts to
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
# Image generation
+group = parser.add_argument_group('Image model')
group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
+group.add_argument('--image-dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], help='Data type for image model.')
+group.add_argument('--image-attn-backend', type=str, default='sdpa', choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
+group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
# Model loader
group = parser.add_argument_group('Model loader')
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index b59a8458..25bfeb21 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,6 +1,13 @@
-import gradio as gr
import os
-from modules.utils import resolve_model_path
+from datetime import datetime
+
+import gradio as gr
+import numpy as np
+import torch
+
+from modules import shared
+from modules.image_models import load_image_model, unload_image_model
+
def create_ui():
with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -92,26 +99,33 @@ def create_ui():
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
- if engine.pipe is None:
- load_pipeline("SDPA", False, False, "bfloat16")
-
- if seed == -1: seed = np.random.randint(0, 2**32 - 1)
-
- # We use a base generator. For sequential batches, we might increment seed if desired,
- # but here we keep the base seed logic consistent.
+ import numpy as np
+ import torch
+ from modules import shared
+ from modules.image_models import load_image_model
+
+ # Auto-load model if not loaded
+ if shared.image_model is None:
+ if shared.image_model_name == 'None':
+ return [], "No image model selected. Please load a model first."
+ load_image_model(shared.image_model_name)
+
+ if shared.image_model is None:
+ return [], "Failed to load image model."
+
+ if seed == -1:
+ seed = np.random.randint(0, 2**32 - 1)
+
generator = torch.Generator("cuda").manual_seed(int(seed))
-
all_images = []
-
- # SEQUENTIAL LOOP (Easy on VRAM)
- for i in range(batch_count_seq):
- # Update seed for subsequent batches so they aren't identical
+
+ # Sequential loop (easier on VRAM)
+ for i in range(int(batch_count_seq)):
current_seed = seed + i
generator.manual_seed(int(current_seed))
-
- # PARALLEL GENERATION (Fast, Heavy VRAM)
- # diffusers handles 'num_images_per_prompt' for parallel execution
- batch_results = engine.pipe(
+
+ # Parallel generation
+ batch_results = shared.image_model(
prompt=prompt,
negative_prompt=neg_prompt,
height=int(height),
@@ -121,13 +135,13 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
num_images_per_prompt=int(batch_size_parallel),
generator=generator,
).images
-
+
all_images.extend(batch_results)
-
+
# Save to disk
save_generated_images(all_images, prompt, seed)
-
- return all_images, seed
+
+ return all_images, f"Seed: {seed}"
# --- File Saving Logic ---
@@ -173,38 +187,3 @@ def get_history_images():
# Sort by time, newest first
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
-
-
-def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
- dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
- target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
-
- if engine.pipe is not None and engine.config["backend"] == attn_backend:
- return gr.Info("Pipeline ready.")
-
- try:
- gr.Info(f"Loading Model ({attn_backend})...")
- pipe = ZImagePipeline.from_pretrained(
- engine.config["model_id"],
- torch_dtype=target_dtype,
- low_cpu_mem_usage=False,
- )
- if not offload_cpu: pipe.to("cuda")
-
- if attn_backend == "Flash Attention 2":
- pipe.transformer.set_attention_backend("flash")
- elif attn_backend == "Flash Attention 3":
- pipe.transformer.set_attention_backend("_flash_3")
-
- if compile_model:
- gr.Warning("Compiling... First run will be slow.")
- pipe.transformer.compile()
-
- if offload_cpu: pipe.enable_model_cpu_offload()
-
- engine.pipe = pipe
- engine.config["backend"] = attn_backend
- return gr.Success("System Ready.")
- except Exception as e:
- return gr.Error(f"Init Failed: {str(e)}")
-
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index dbbd3274..cb3508f8 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -140,7 +140,7 @@ def create_ui():
with gr.Column():
with gr.Row():
shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
@@ -169,7 +169,7 @@ def create_ui():
shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
with gr.Row():
- shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+ shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
def create_event_handlers():
@@ -220,6 +220,28 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
+ # Image model event handlers
+ shared.gradio['image_load_model'].click(
+ load_image_model_wrapper,
+ gradio('image_model_menu'),
+ gradio('image_model_status'),
+ show_progress=True
+ )
+
+ shared.gradio['image_unload_model'].click(
+ handle_unload_image_model_click,
+ None,
+ gradio('image_model_status'),
+ show_progress=False
+ )
+
+ shared.gradio['image_download_model_button'].click(
+ download_image_model_wrapper,
+ gradio('image_custom_model_menu'),
+ gradio('image_model_status'),
+ show_progress=True
+ )
+
def load_model_wrapper(selected_model, loader, autoload=False):
try:
@@ -471,3 +493,66 @@ def format_file_size(size_bytes):
return f"{s:.2f} {size_names[i]}"
else:
return f"{s:.1f} {size_names[i]}"
+
+
+def load_image_model_wrapper(selected_model):
+ """Wrapper for loading image models with status updates."""
+ from modules.image_models import load_image_model, unload_image_model
+
+ if selected_model == 'None' or not selected_model:
+ yield "No model selected"
+ return
+
+ try:
+ yield f"Loading `{selected_model}`..."
+ unload_image_model()
+ result = load_image_model(selected_model)
+
+ if result is not None:
+ yield f"Successfully loaded `{selected_model}`."
+ else:
+ yield f"Failed to load `{selected_model}`."
+ except Exception:
+ exc = traceback.format_exc()
+ yield exc.replace('\n', '\n\n')
+
+
+def handle_unload_image_model_click():
+ """Handler for the image model unload button."""
+ from modules.image_models import unload_image_model
+ unload_image_model()
+ return "Image model unloaded"
+
+
+def download_image_model_wrapper(custom_model):
+ """Download an image model from Hugging Face."""
+ from huggingface_hub import snapshot_download
+
+ if not custom_model:
+ yield "No model specified"
+ return
+
+ try:
+ # Parse model name and branch
+ if ':' in custom_model:
+ model_name, branch = custom_model.rsplit(':', 1)
+ else:
+ model_name, branch = custom_model, 'main'
+
+ # Output folder
+ output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
+
+ yield f"Downloading `{model_name}` (branch: {branch})..."
+
+ snapshot_download(
+ repo_id=model_name,
+ revision=branch,
+ local_dir=output_folder,
+ local_dir_use_symlinks=False,
+ )
+
+ yield f"Model successfully saved to `{output_folder}/`."
+
+ except Exception:
+ exc = traceback.format_exc()
+ yield exc.replace('\n', '\n\n')
From be799ba8ebbf1aebe279160297effa98d326b306 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 14:25:49 -0800
Subject: [PATCH 06/38] Lint
---
modules/ui_image_generation.py | 32 ++++++++++++++++----------------
modules/ui_model_menu.py | 4 ++--
2 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 25bfeb21..ec1799dd 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -15,10 +15,10 @@ def create_ui():
# TAB 1: GENERATION STUDIO
with gr.TabItem("Generate Images"):
with gr.Row():
-
+
# === LEFT COLUMN: CONTROLS ===
with gr.Column(scale=4, min_width=350):
-
+
# 1. PROMPT
prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
@@ -54,12 +54,12 @@ def create_ui():
with gr.Column():
batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
-
+
# === RIGHT COLUMN: VIEWPORT ===
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
output_gallery = gr.Gallery(
- label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
+ label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
)
with gr.Row():
used_seed = gr.Markdown(label="Info", interactive=False, lines=3)
@@ -74,7 +74,7 @@ def create_ui():
)
# === WIRING ===
-
+
# Aspect Buttons
# btn_sq.click(lambda: set_dims(1024, 1024), outputs=[width_slider, height_slider])
# btn_port.click(lambda: set_dims(720, 1280), outputs=[width_slider, height_slider])
@@ -91,7 +91,7 @@ def create_ui():
# System
# load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
-
+
# History
# refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)
# Load history on app launch
@@ -103,27 +103,27 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
import torch
from modules import shared
from modules.image_models import load_image_model
-
+
# Auto-load model if not loaded
if shared.image_model is None:
if shared.image_model_name == 'None':
return [], "No image model selected. Please load a model first."
load_image_model(shared.image_model_name)
-
+
if shared.image_model is None:
return [], "Failed to load image model."
-
+
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
-
+
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
-
+
# Sequential loop (easier on VRAM)
for i in range(int(batch_count_seq)):
current_seed = seed + i
generator.manual_seed(int(current_seed))
-
+
# Parallel generation
batch_results = shared.image_model(
prompt=prompt,
@@ -135,12 +135,12 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
num_images_per_prompt=int(batch_size_parallel),
generator=generator,
).images
-
+
all_images.extend(batch_results)
-
+
# Save to disk
save_generated_images(all_images, prompt, seed)
-
+
return all_images, f"Seed: {seed}"
@@ -163,7 +163,7 @@ def save_generated_images(images, prompt, seed):
img.save(full_path)
saved_paths.append(full_path)
- # Optional: Save prompt metadata in a text file next to it?
+ # Optional: Save prompt metadata in a text file next to it?
# For now, we just save the image.
return saved_paths
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index cb3508f8..c4cce35d 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -227,14 +227,14 @@ def create_event_handlers():
gradio('image_model_status'),
show_progress=True
)
-
+
shared.gradio['image_unload_model'].click(
handle_unload_image_model_click,
None,
gradio('image_model_status'),
show_progress=False
)
-
+
shared.gradio['image_download_model_button'].click(
download_image_model_wrapper,
gradio('image_custom_model_menu'),
From aa074409cb1bdc652faa56fd5bff7fe0c43dd0d0 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 14:38:50 -0800
Subject: [PATCH 07/38] Better events for the dimensions
---
modules/ui_image_generation.py | 165 +++++++++++++++++++++++++++++++--
1 file changed, 159 insertions(+), 6 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index ec1799dd..df87c4ab 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -9,6 +9,18 @@ from modules import shared
from modules.image_models import load_image_model, unload_image_model
+# Aspect ratio definitions: name -> (width_ratio, height_ratio)
+ASPECT_RATIOS = {
+ "1:1 Square": (1, 1),
+ "16:9 Cinema": (16, 9),
+ "9:16 Mobile": (9, 16),
+ "4:3 Photo": (4, 3),
+ "Custom": None,
+}
+
+STEP = 32 # Slider step for rounding
+
+
def create_ui():
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
@@ -36,12 +48,14 @@ def create_ui():
with gr.Column():
height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height")
- preset_radio = gr.Radio(
- choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
- value="1:1 Square",
- label="Aspect Ratio",
- interactive=True
- )
+ with gr.Row():
+ preset_radio = gr.Radio(
+ choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
+ value="1:1 Square",
+ label="Aspect Ratio",
+ interactive=True
+ )
+ swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80)
# 4. SETTINGS & BATCHING
gr.Markdown("### ⚙️ Config")
@@ -75,6 +89,49 @@ def create_ui():
# === WIRING ===
+ # Aspect ratio preset changes -> update dimensions
+ preset_radio.change(
+ fn=apply_aspect_ratio,
+ inputs=[preset_radio, width_slider, height_slider],
+ outputs=[width_slider, height_slider],
+ show_progress=False
+ )
+
+ # Width slider changes -> update height (if not Custom)
+ width_slider.release(
+ fn=update_height_from_width,
+ inputs=[width_slider, preset_radio],
+ outputs=[height_slider],
+ show_progress=False
+ )
+
+ # Height slider changes -> update width (if not Custom)
+ height_slider.release(
+ fn=update_width_from_height,
+ inputs=[height_slider, preset_radio],
+ outputs=[width_slider],
+ show_progress=False
+ )
+
+ # Swap button -> swap dimensions and update aspect ratio
+ swap_btn.click(
+ fn=swap_dimensions_and_update_ratio,
+ inputs=[width_slider, height_slider, preset_radio],
+ outputs=[width_slider, height_slider, preset_radio],
+ show_progress=False
+ )
+
+ # Generation
+ inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
+ outputs = [output_gallery, used_seed]
+
+ generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+ prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+ neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+
+ # History
+ # refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)
+
# Aspect Buttons
# btn_sq.click(lambda: set_dims(1024, 1024), outputs=[width_slider, height_slider])
# btn_port.click(lambda: set_dims(720, 1280), outputs=[width_slider, height_slider])
@@ -187,3 +244,99 @@ def get_history_images():
# Sort by time, newest first
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
+
+
+def round_to_step(value, step=STEP):
+ """Round a value to the nearest step."""
+ return round(value / step) * step
+
+
+def clamp(value, min_val, max_val):
+ """Clamp value between min and max."""
+ return max(min_val, min(max_val, value))
+
+
+def apply_aspect_ratio(aspect_ratio, current_width, current_height):
+ """
+ Apply an aspect ratio preset.
+
+ Logic to prevent dimension creep:
+ - For tall ratios (like 9:16): keep width fixed, calculate height
+ - For wide ratios (like 16:9): keep height fixed, calculate width
+ - For square (1:1): use the smaller of the current dimensions
+
+ Returns (new_width, new_height).
+ """
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return current_width, current_height
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+
+ if w_ratio == h_ratio:
+ # Square ratio - use the smaller current dimension to prevent creep
+ base = min(current_width, current_height)
+ new_width = base
+ new_height = base
+ elif w_ratio < h_ratio:
+ # Tall ratio (like 9:16) - width is the smaller side, keep it fixed
+ new_width = current_width
+ new_height = round_to_step(current_width * h_ratio / w_ratio)
+ else:
+ # Wide ratio (like 16:9) - height is the smaller side, keep it fixed
+ new_height = current_height
+ new_width = round_to_step(current_height * w_ratio / h_ratio)
+
+ # Clamp to slider bounds
+ new_width = clamp(new_width, 256, 2048)
+ new_height = clamp(new_height, 256, 2048)
+
+ return int(new_width), int(new_height)
+
+
+def update_height_from_width(width, aspect_ratio):
+ """Update height when width changes (if not Custom)."""
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return gr.update()
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+ new_height = round_to_step(width * h_ratio / w_ratio)
+ new_height = clamp(new_height, 256, 2048)
+
+ return int(new_height)
+
+
+def update_width_from_height(height, aspect_ratio):
+ """Update width when height changes (if not Custom)."""
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return gr.update()
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+ new_width = round_to_step(height * w_ratio / h_ratio)
+ new_width = clamp(new_width, 256, 2048)
+
+ return int(new_width)
+
+
+def swap_dimensions(width, height):
+ """Swap width and height values."""
+ return height, width
+
+
+def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
+ """Swap dimensions and update aspect ratio to match (or set to Custom)."""
+ new_width, new_height = height, width
+
+ # Try to find a matching aspect ratio for the swapped dimensions
+ new_ratio = "Custom"
+ for name, ratios in ASPECT_RATIOS.items():
+ if ratios is None:
+ continue
+ w_r, h_r = ratios
+ # Check if the swapped dimensions match this ratio (within tolerance)
+ expected_height = new_width * h_r / w_r
+ if abs(expected_height - new_height) < STEP:
+ new_ratio = name
+ break
+
+ return new_width, new_height, new_ratio
+
From 0adda7a5c57d3f5f91be1015545319de548f3d15 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 14:39:21 -0800
Subject: [PATCH 08/38] Lint
---
modules/ui_image_generation.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index df87c4ab..d9f3ae30 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -339,4 +339,3 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
break
return new_width, new_height, new_ratio
-
From 148a5d1e44a60372e960b677b4513ecb519d9207 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 15:32:01 -0800
Subject: [PATCH 09/38] Keep things more modular
---
modules/shared.py | 7 +-
modules/ui_image_generation.py | 502 ++++++++++++++++++++++-----------
modules/ui_model_menu.py | 247 ++++++----------
3 files changed, 433 insertions(+), 323 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index 66666b75..c77259e6 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -52,11 +52,12 @@ group.add_argument('--idle-timeout', type=int, default=0, help='Unload model aft
# Image generation
group = parser.add_argument_group('Image model')
-group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
+group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
-group.add_argument('--image-dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], help='Data type for image model.')
-group.add_argument('--image-attn-backend', type=str, default='sdpa', choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
+group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
+group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
+group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
# Model loader
group = parser.add_argument_group('Model loader')
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index d9f3ae30..e01f8ea7 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,12 +1,16 @@
+# modules/ui_image_generation.py
import os
+import traceback
from datetime import datetime
+from pathlib import Path
import gradio as gr
import numpy as np
import torch
-from modules import shared
+from modules import shared, utils
from modules.image_models import load_image_model, unload_image_model
+from modules.image_model_settings import get_effective_settings, save_image_model_settings
# Aspect ratio definitions: name -> (width_ratio, height_ratio)
@@ -21,16 +25,113 @@ ASPECT_RATIOS = {
STEP = 32 # Slider step for rounding
+def round_to_step(value, step=STEP):
+ """Round a value to the nearest step."""
+ return round(value / step) * step
+
+
+def clamp(value, min_val, max_val):
+ """Clamp value between min and max."""
+ return max(min_val, min(max_val, value))
+
+
+def apply_aspect_ratio(aspect_ratio, current_width, current_height):
+ """
+ Apply an aspect ratio preset.
+
+ Logic to prevent dimension creep:
+ - For tall ratios (like 9:16): keep width fixed, calculate height
+ - For wide ratios (like 16:9): keep height fixed, calculate width
+ - For square (1:1): use the smaller of the current dimensions
+
+ Returns (new_width, new_height).
+ """
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return current_width, current_height
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+
+ if w_ratio == h_ratio:
+ # Square ratio - use the smaller current dimension to prevent creep
+ base = min(current_width, current_height)
+ new_width = base
+ new_height = base
+ elif w_ratio < h_ratio:
+ # Tall ratio (like 9:16) - width is the smaller side, keep it fixed
+ new_width = current_width
+ new_height = round_to_step(current_width * h_ratio / w_ratio)
+ else:
+ # Wide ratio (like 16:9) - height is the smaller side, keep it fixed
+ new_height = current_height
+ new_width = round_to_step(current_height * w_ratio / h_ratio)
+
+ # Clamp to slider bounds
+ new_width = clamp(new_width, 256, 2048)
+ new_height = clamp(new_height, 256, 2048)
+
+ return int(new_width), int(new_height)
+
+
+def update_height_from_width(width, aspect_ratio):
+ """Update height when width changes (if not Custom)."""
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return gr.update()
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+ new_height = round_to_step(width * h_ratio / w_ratio)
+ new_height = clamp(new_height, 256, 2048)
+
+ return int(new_height)
+
+
+def update_width_from_height(height, aspect_ratio):
+ """Update width when height changes (if not Custom)."""
+ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+ return gr.update()
+
+ w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+ new_width = round_to_step(height * w_ratio / h_ratio)
+ new_width = clamp(new_width, 256, 2048)
+
+ return int(new_width)
+
+
+def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
+ """Swap dimensions and update aspect ratio to match (or set to Custom)."""
+ new_width, new_height = height, width
+
+ # Try to find a matching aspect ratio for the swapped dimensions
+ new_ratio = "Custom"
+ for name, ratios in ASPECT_RATIOS.items():
+ if ratios is None:
+ continue
+ w_r, h_r = ratios
+ # Check if the swapped dimensions match this ratio (within tolerance)
+ expected_height = new_width * h_r / w_r
+ if abs(expected_height - new_height) < STEP:
+ new_ratio = name
+ break
+
+ return new_width, new_height, new_ratio
+
+
def create_ui():
+ # Get effective settings (CLI > yaml > defaults)
+ settings = get_effective_settings()
+
+ # Update shared state (but don't load the model yet)
+ if settings['model_name'] != 'None':
+ shared.image_model_name = settings['model_name']
+
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
# TAB 1: GENERATION STUDIO
- with gr.TabItem("Generate Images"):
+ with gr.TabItem("Generate"):
with gr.Row():
-
+
# === LEFT COLUMN: CONTROLS ===
with gr.Column(scale=4, min_width=350):
-
+
# 1. PROMPT
prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
@@ -58,7 +159,7 @@ def create_ui():
swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80)
# 4. SETTINGS & BATCHING
- gr.Markdown("### ⚙️ Config")
+ gr.Markdown("### ⚙️ Config")
with gr.Row():
with gr.Column():
steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps")
@@ -68,15 +169,15 @@ def create_ui():
with gr.Column():
batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
-
+
# === RIGHT COLUMN: VIEWPORT ===
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
output_gallery = gr.Gallery(
- label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
+ label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
)
with gr.Row():
- used_seed = gr.Markdown(label="Info", interactive=False, lines=3)
+ used_seed = gr.Markdown(label="Info", interactive=False)
# TAB 2: HISTORY VIEWER
with gr.TabItem("Gallery"):
@@ -87,8 +188,67 @@ def create_ui():
label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True
)
- # === WIRING ===
+ # TAB 3: MODEL SETTINGS
+ with gr.TabItem("Model"):
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ image_model_menu = gr.Dropdown(
+ choices=utils.get_available_image_models(),
+ value=settings['model_name'],
+ label='Model',
+ elem_classes='slim-dropdown'
+ )
+ image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
+
+ with gr.Row():
+ image_load_model = gr.Button("Load", variant='primary')
+ image_unload_model = gr.Button("Unload")
+
+ gr.Markdown("### Settings")
+
+ image_dtype = gr.Dropdown(
+ choices=['bfloat16', 'float16'],
+ value=settings['dtype'],
+ label='Data Type',
+ info='bfloat16 recommended for modern GPUs'
+ )
+
+ image_attn_backend = gr.Dropdown(
+ choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
+ value=settings['attn_backend'],
+ label='Attention Backend',
+ info='SDPA is default. Flash Attention requires compatible GPU.'
+ )
+
+ image_cpu_offload = gr.Checkbox(
+ value=settings['cpu_offload'],
+ label='CPU Offload',
+ info='Enable for low VRAM GPUs. Slower but uses less memory.'
+ )
+
+ image_compile = gr.Checkbox(
+ value=settings['compile_model'],
+ label='Compile Model',
+ info='Faster inference after first run. First run will be slow.'
+ )
+
+ image_model_status = gr.Markdown(
+ value=f"Model: **{settings['model_name']}** (not loaded)" if settings['model_name'] != 'None' else "No model selected"
+ )
+ with gr.Column():
+ gr.Markdown("### Download Model")
+ image_download_path = gr.Textbox(
+ label="Hugging Face Model",
+ placeholder="Tongyi-MAI/Z-Image-Turbo",
+ info="Enter the HuggingFace model path. Use : for branch, e.g. model:main"
+ )
+ image_download_btn = gr.Button("Download", variant='primary')
+ image_download_status = gr.Markdown("")
+
+ # === WIRING ===
+
# Aspect ratio preset changes -> update dimensions
preset_radio.change(
fn=apply_aspect_ratio,
@@ -96,7 +256,7 @@ def create_ui():
outputs=[width_slider, height_slider],
show_progress=False
)
-
+
# Width slider changes -> update height (if not Custom)
width_slider.release(
fn=update_height_from_width,
@@ -104,7 +264,7 @@ def create_ui():
outputs=[height_slider],
show_progress=False
)
-
+
# Height slider changes -> update width (if not Custom)
height_slider.release(
fn=update_width_from_height,
@@ -112,7 +272,7 @@ def create_ui():
outputs=[width_slider],
show_progress=False
)
-
+
# Swap button -> swap dimensions and update aspect ratio
swap_btn.click(
fn=swap_dimensions_and_update_ratio,
@@ -125,62 +285,92 @@ def create_ui():
inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
outputs = [output_gallery, used_seed]
- generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
- prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
- neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
-
+ generate_btn.click(
+ fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
+ inputs=inputs,
+ outputs=outputs
+ )
+ prompt.submit(
+ fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
+ inputs=inputs,
+ outputs=outputs
+ )
+ neg_prompt.submit(
+ fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
+ inputs=inputs,
+ outputs=outputs
+ )
+
+ # Model tab events
+ image_refresh_models.click(
+ fn=lambda: gr.update(choices=utils.get_available_image_models()),
+ inputs=None,
+ outputs=[image_model_menu],
+ show_progress=False
+ )
+
+ image_load_model.click(
+ fn=load_image_model_wrapper,
+ inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
+ outputs=[image_model_status],
+ show_progress=True
+ )
+
+ image_unload_model.click(
+ fn=unload_image_model_wrapper,
+ inputs=None,
+ outputs=[image_model_status],
+ show_progress=False
+ )
+
+ image_download_btn.click(
+ fn=download_image_model_wrapper,
+ inputs=[image_download_path],
+ outputs=[image_download_status, image_model_menu],
+ show_progress=True
+ )
+
# History
- # refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)
-
- # Aspect Buttons
- # btn_sq.click(lambda: set_dims(1024, 1024), outputs=[width_slider, height_slider])
- # btn_port.click(lambda: set_dims(720, 1280), outputs=[width_slider, height_slider])
- # btn_land.click(lambda: set_dims(1280, 720), outputs=[width_slider, height_slider])
- # btn_wide.click(lambda: set_dims(1536, 640), outputs=[width_slider, height_slider])
-
- # Generation
- inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
- outputs = [output_gallery, used_seed]
-
- generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
- prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
- neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
-
- # System
- # load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
-
- # History
- # refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)
- # Load history on app launch
- # demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
+ refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
-def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
- import numpy as np
- import torch
- from modules import shared
- from modules.image_models import load_image_model
-
+def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
+ model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
+ """Generate images with the current model settings."""
+
+ # Get current UI values (these are Gradio components, we need their values)
+ model_name = shared.image_model_name
+
+ if model_name == 'None':
+ return [], "No image model selected. Go to the Model tab and select a model."
+
# Auto-load model if not loaded
if shared.image_model is None:
- if shared.image_model_name == 'None':
- return [], "No image model selected. Please load a model first."
- load_image_model(shared.image_model_name)
-
- if shared.image_model is None:
- return [], "Failed to load image model."
-
+ # Load saved settings for the model
+ saved_settings = load_image_model_settings()
+
+ result = load_image_model(
+ model_name,
+ dtype=saved_settings['dtype'],
+ attn_backend=saved_settings['attn_backend'],
+ cpu_offload=saved_settings['cpu_offload'],
+ compile_model=saved_settings['compile_model']
+ )
+
+ if result is None:
+ return [], f"Failed to load model `{model_name}`."
+
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
-
+
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
-
+
# Sequential loop (easier on VRAM)
for i in range(int(batch_count_seq)):
current_seed = seed + i
generator.manual_seed(int(current_seed))
-
+
# Parallel generation
batch_results = shared.image_model(
prompt=prompt,
@@ -192,150 +382,128 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
num_images_per_prompt=int(batch_size_parallel),
generator=generator,
).images
-
+
all_images.extend(batch_results)
-
+
# Save to disk
save_generated_images(all_images, prompt, seed)
-
+
return all_images, f"Seed: {seed}"
-# --- File Saving Logic ---
+def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
+ """Load model and save settings."""
+ if model_name == 'None' or not model_name:
+ yield "No model selected"
+ return
+
+ try:
+ yield f"Loading `{model_name}`..."
+
+ # Unload existing model first
+ unload_image_model()
+
+ # Load the new model
+ result = load_image_model(
+ model_name,
+ dtype=dtype,
+ attn_backend=attn_backend,
+ cpu_offload=cpu_offload,
+ compile_model=compile_model
+ )
+
+ if result is not None:
+ # Save settings to yaml
+ save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
+ yield f"✓ Loaded **{model_name}**"
+ else:
+ yield f"✗ Failed to load `{model_name}`"
+
+ except Exception:
+ exc = traceback.format_exc()
+ yield f"Error:\n```\n{exc}\n```"
+
+
+def unload_image_model_wrapper():
+ """Unload model wrapper."""
+ unload_image_model()
+
+ if shared.image_model_name != 'None':
+ return f"Model: **{shared.image_model_name}** (not loaded)"
+ else:
+ return "No model loaded"
+
+
+def download_image_model_wrapper(model_path):
+ """Download a model from Hugging Face."""
+ from huggingface_hub import snapshot_download
+
+ if not model_path:
+ yield "No model specified", gr.update()
+ return
+
+ try:
+ # Parse model name and branch
+ if ':' in model_path:
+ model_id, branch = model_path.rsplit(':', 1)
+ else:
+ model_id, branch = model_path, 'main'
+
+ # Output folder name
+ folder_name = model_id.split('/')[-1]
+ output_folder = Path(shared.args.image_model_dir) / folder_name
+
+ yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
+
+ snapshot_download(
+ repo_id=model_id,
+ revision=branch,
+ local_dir=output_folder,
+ local_dir_use_symlinks=False,
+ )
+
+ # Refresh the model list
+ new_choices = utils.get_available_image_models()
+
+ yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
+
+ except Exception:
+ exc = traceback.format_exc()
+ yield f"Error:\n```\n{exc}\n```", gr.update()
+
+
def save_generated_images(images, prompt, seed):
- # Create folder structure: outputs/YYYY-MM-DD/
+ """Save generated images to disk."""
date_str = datetime.now().strftime("%Y-%m-%d")
- folder_path = os.path.join("outputs", date_str)
+ folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
saved_paths = []
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
- # Filename: Time_Seed_Index.png
filename = f"{timestamp}_{seed}_{idx}.png"
full_path = os.path.join(folder_path, filename)
- # Save image
img.save(full_path)
saved_paths.append(full_path)
- # Optional: Save prompt metadata in a text file next to it?
- # For now, we just save the image.
-
return saved_paths
-# --- History Logic ---
def get_history_images():
- """Scans the outputs folder and returns all images, newest first"""
- if not os.path.exists("outputs"):
+ """Scan the outputs folder and return all images, newest first."""
+ output_dir = os.path.join("user_data", "image_outputs")
+ if not os.path.exists(output_dir):
return []
image_files = []
- for root, dirs, files in os.walk("outputs"):
+ for root, dirs, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
- # Get creation time for sorting
mtime = os.path.getmtime(full_path)
image_files.append((full_path, mtime))
- # Sort by time, newest first
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
-
-
-def round_to_step(value, step=STEP):
- """Round a value to the nearest step."""
- return round(value / step) * step
-
-
-def clamp(value, min_val, max_val):
- """Clamp value between min and max."""
- return max(min_val, min(max_val, value))
-
-
-def apply_aspect_ratio(aspect_ratio, current_width, current_height):
- """
- Apply an aspect ratio preset.
-
- Logic to prevent dimension creep:
- - For tall ratios (like 9:16): keep width fixed, calculate height
- - For wide ratios (like 16:9): keep height fixed, calculate width
- - For square (1:1): use the smaller of the current dimensions
-
- Returns (new_width, new_height).
- """
- if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
- return current_width, current_height
-
- w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
-
- if w_ratio == h_ratio:
- # Square ratio - use the smaller current dimension to prevent creep
- base = min(current_width, current_height)
- new_width = base
- new_height = base
- elif w_ratio < h_ratio:
- # Tall ratio (like 9:16) - width is the smaller side, keep it fixed
- new_width = current_width
- new_height = round_to_step(current_width * h_ratio / w_ratio)
- else:
- # Wide ratio (like 16:9) - height is the smaller side, keep it fixed
- new_height = current_height
- new_width = round_to_step(current_height * w_ratio / h_ratio)
-
- # Clamp to slider bounds
- new_width = clamp(new_width, 256, 2048)
- new_height = clamp(new_height, 256, 2048)
-
- return int(new_width), int(new_height)
-
-
-def update_height_from_width(width, aspect_ratio):
- """Update height when width changes (if not Custom)."""
- if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
- return gr.update()
-
- w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
- new_height = round_to_step(width * h_ratio / w_ratio)
- new_height = clamp(new_height, 256, 2048)
-
- return int(new_height)
-
-
-def update_width_from_height(height, aspect_ratio):
- """Update width when height changes (if not Custom)."""
- if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
- return gr.update()
-
- w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
- new_width = round_to_step(height * w_ratio / h_ratio)
- new_width = clamp(new_width, 256, 2048)
-
- return int(new_width)
-
-
-def swap_dimensions(width, height):
- """Swap width and height values."""
- return height, width
-
-
-def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
- """Swap dimensions and update aspect ratio to match (or set to Custom)."""
- new_width, new_height = height, width
-
- # Try to find a matching aspect ratio for the swapped dimensions
- new_ratio = "Custom"
- for name, ratios in ASPECT_RATIOS.items():
- if ratios is None:
- continue
- w_r, h_r = ratios
- # Check if the swapped dimensions match this ratio (within tolerance)
- expected_height = new_width * h_r / w_r
- if abs(expected_height - new_height) < STEP:
- new_ratio = name
- break
-
- return new_width, new_height, new_ratio
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index c4cce35d..a5e0f640 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -27,149 +27,112 @@ def create_ui():
mu = shared.args.multi_user
with gr.Tab("Model", elem_id="model-tab"):
- with gr.Tab("Text model"):
- with gr.Row():
- with gr.Column():
- with gr.Row():
- shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=lambda: shared.model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
- shared.gradio['load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=lambda: shared.model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
+ shared.gradio['load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys() if not shared.args.portable else ['llama.cpp'], value=None)
- with gr.Blocks():
- gr.Markdown("## Main options")
+ shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys() if not shared.args.portable else ['llama.cpp'], value=None)
+ with gr.Blocks():
+ gr.Markdown("## Main options")
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
+ shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
+ shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
+ shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
+ shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
+ shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
+
+ with gr.Column():
+ shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
+ shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
+ shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
+ shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
+ shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
+ shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
+ shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
+ shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
+ shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
+ shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
+
+ # Multimodal
+ with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
+ with gr.Row():
+ shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
+
+ # Speculative decoding
+ with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
+ with gr.Row():
+ shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
+
+ shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
+ shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
+ shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
+ shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
+
+ gr.Markdown("## Other options")
+ with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
with gr.Row():
with gr.Column():
- shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=0, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
- shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=256, maximum=131072, step=256, value=shared.args.ctx_size, info='Context length. Common values: 4096, 8192, 16384, 32768, 65536, 131072.')
- shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
- shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.')
- shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
- shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.')
+ shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
+ shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
+ shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
+ shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
+ shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
+ shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
+ shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
+ shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
+ shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
+ shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
+ shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
+ shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
+ shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
with gr.Column():
- shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info())
- shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. Saves VRAM on MoE models.')
- shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
- shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
- shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
- shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.')
- shared.gradio['autosplit'] = gr.Checkbox(label="autosplit", value=shared.args.autosplit, info='Automatically split the model tensors across the available GPUs.')
- shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).')
- shared.gradio['cpp_runner'] = gr.Checkbox(label="cpp-runner", value=shared.args.cpp_runner, info='Enable inference with ModelRunnerCpp, which is faster than the default ModelRunner.')
- shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `ctx_size` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.')
-
- # Multimodal
- with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']:
+ shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
+ shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+ shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
+ shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
+ shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
+ shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
+ shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
+ shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+ shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
+ shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
+ shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
+ shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
+ shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
+ if not shared.args.portable:
with gr.Row():
- shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info='Select a file that matches your model. Must be placed in user_data/mmproj/', interactive=not mu)
- ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu)
-
- # Speculative decoding
- with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']:
- with gr.Row():
- shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Speculative decoding only works with models sharing the same vocabulary (e.g., same model family).', interactive=not mu)
- ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu)
-
- shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.')
- shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
- shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
- shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
-
- gr.Markdown("## Other options")
- with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
- with gr.Row():
- with gr.Column():
- shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
- shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
- shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
- shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size)
- shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
- shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags)
- shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
- shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
- shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
- shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
- shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.')
- shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.')
- shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.')
-
- with gr.Column():
- shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.')
- shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
- shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')
- shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
- shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
- shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
- shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.')
- shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
- shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn)
- shared.gradio['no_xformers'] = gr.Checkbox(label="no_xformers", value=shared.args.no_xformers)
- shared.gradio['no_sdpa'] = gr.Checkbox(label="no_sdpa", value=shared.args.no_sdpa)
- shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.')
- shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
- if not shared.args.portable:
- with gr.Row():
- shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
- shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)
-
- with gr.Column():
- with gr.Tab("Download"):
- shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
- shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1, interactive=not mu)
- with gr.Row():
- shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
- shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)
-
- with gr.Tab("Customize instruction template"):
- with gr.Row():
- shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
- ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
-
- shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
- gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
+ shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown', interactive=not mu)
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
+ shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)
+ with gr.Column():
+ with gr.Tab("Download"):
+ shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
+ shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1, interactive=not mu)
with gr.Row():
- shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+ shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
+ shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu)
- with gr.Tab("Image model"):
- with gr.Row():
- with gr.Column():
+ with gr.Tab("Customize instruction template"):
with gr.Row():
- shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
- ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
- shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
- shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
+ shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
+ ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
- with gr.Blocks():
- gr.Markdown("## Main options")
- with gr.Row():
- with gr.Column():
- pass
+ shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
+ gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
- with gr.Column():
- pass
-
- gr.Markdown("## Other options")
- with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
- with gr.Row():
- with gr.Column():
- pass
-
- with gr.Column():
- pass
-
- with gr.Column():
- shared.gradio['image_custom_model_menu'] = gr.Textbox(label="Download model (diffusers format)", info="Enter the Hugging Face username/model path, for instance: Tongyi-MAI/Z-Image-Turbo. To specify a branch, add it at the end after a \":\" character like this: Tongyi-MAI/Z-Image-Turbo:main.", interactive=not mu)
- with gr.Row():
- shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
-
- with gr.Row():
- shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
+ with gr.Row():
+ shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
def create_event_handlers():
@@ -220,28 +183,6 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
- # Image model event handlers
- shared.gradio['image_load_model'].click(
- load_image_model_wrapper,
- gradio('image_model_menu'),
- gradio('image_model_status'),
- show_progress=True
- )
-
- shared.gradio['image_unload_model'].click(
- handle_unload_image_model_click,
- None,
- gradio('image_model_status'),
- show_progress=False
- )
-
- shared.gradio['image_download_model_button'].click(
- download_image_model_wrapper,
- gradio('image_custom_model_menu'),
- gradio('image_model_status'),
- show_progress=True
- )
-
def load_model_wrapper(selected_model, loader, autoload=False):
try:
From 21f992e7f7d6e1cf45d6890c63f31ba0be6c827d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 15:42:11 -0800
Subject: [PATCH 10/38] Organize the UI
---
css/main.css | 7 +++
modules/ui_image_generation.py | 83 +++++++++++++++++-----------------
2 files changed, 48 insertions(+), 42 deletions(-)
diff --git a/css/main.css b/css/main.css
index 61a33a4b..4e3a0658 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1674,3 +1674,10 @@ button:focus {
.dark .sidebar-vertical-separator {
border-bottom: 1px solid rgb(255 255 255 / 10%);
}
+
+button#swap-height-width {
+ position: absolute;
+ top: -50px;
+ right: 0px;
+ border: 0;
+}
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index e01f8ea7..030379fd 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -149,6 +149,8 @@ def create_ui():
with gr.Column():
height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height")
+ swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
+
with gr.Row():
preset_radio = gr.Radio(
choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
@@ -156,7 +158,6 @@ def create_ui():
label="Aspect Ratio",
interactive=True
)
- swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80)
# 4. SETTINGS & BATCHING
gr.Markdown("### ⚙️ Config")
@@ -200,53 +201,51 @@ def create_ui():
elem_classes='slim-dropdown'
)
image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
+ image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
+ image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
+
+ gr.Markdown("## Settings")
with gr.Row():
- image_load_model = gr.Button("Load", variant='primary')
- image_unload_model = gr.Button("Unload")
-
- gr.Markdown("### Settings")
-
- image_dtype = gr.Dropdown(
- choices=['bfloat16', 'float16'],
- value=settings['dtype'],
- label='Data Type',
- info='bfloat16 recommended for modern GPUs'
+ with gr.Column():
+ image_dtype = gr.Dropdown(
+ choices=['bfloat16', 'float16'],
+ value=settings['dtype'],
+ label='Data Type',
+ info='bfloat16 recommended for modern GPUs'
+ )
+
+ image_attn_backend = gr.Dropdown(
+ choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
+ value=settings['attn_backend'],
+ label='Attention Backend',
+ info='SDPA is default. Flash Attention requires compatible GPU.'
+ )
+
+ with gr.Column():
+ image_compile = gr.Checkbox(
+ value=settings['compile_model'],
+ label='Compile Model',
+ info='Faster inference after first run. First run will be slow.'
+ )
+
+ image_cpu_offload = gr.Checkbox(
+ value=settings['cpu_offload'],
+ label='CPU Offload',
+ info='Enable for low VRAM GPUs. Slower but uses less memory.'
+ )
+
+ with gr.Column():
+ image_download_path = gr.Textbox(
+ label="Download model",
+ placeholder="Tongyi-MAI/Z-Image-Turbo",
+ info="Enter the HuggingFace model path like Tongyi-MAI/Z-Image-Turbo. Use : for branch, e.g. Tongyi-MAI/Z-Image-Turbo:main"
)
-
- image_attn_backend = gr.Dropdown(
- choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
- value=settings['attn_backend'],
- label='Attention Backend',
- info='SDPA is default. Flash Attention requires compatible GPU.'
- )
-
- image_cpu_offload = gr.Checkbox(
- value=settings['cpu_offload'],
- label='CPU Offload',
- info='Enable for low VRAM GPUs. Slower but uses less memory.'
- )
-
- image_compile = gr.Checkbox(
- value=settings['compile_model'],
- label='Compile Model',
- info='Faster inference after first run. First run will be slow.'
- )
-
+ image_download_btn = gr.Button("Download", variant='primary')
image_model_status = gr.Markdown(
value=f"Model: **{settings['model_name']}** (not loaded)" if settings['model_name'] != 'None' else "No model selected"
)
- with gr.Column():
- gr.Markdown("### Download Model")
- image_download_path = gr.Textbox(
- label="Hugging Face Model",
- placeholder="Tongyi-MAI/Z-Image-Turbo",
- info="Enter the HuggingFace model path. Use : for branch, e.g. model:main"
- )
- image_download_btn = gr.Button("Download", variant='primary')
- image_download_status = gr.Markdown("")
-
# === WIRING ===
# Aspect ratio preset changes -> update dimensions
@@ -326,7 +325,7 @@ def create_ui():
image_download_btn.click(
fn=download_image_model_wrapper,
inputs=[image_download_path],
- outputs=[image_download_status, image_model_menu],
+ outputs=[image_model_status, image_model_menu],
show_progress=True
)
From 666816a773c5e7c1ba3945c7eca301addb1958a8 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 15:48:53 -0800
Subject: [PATCH 11/38] Small fixes
---
modules/ui_image_generation.py | 131 ++++++++++++++++-----------------
1 file changed, 65 insertions(+), 66 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 030379fd..749ca981 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -38,19 +38,19 @@ def clamp(value, min_val, max_val):
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
"""
Apply an aspect ratio preset.
-
+
Logic to prevent dimension creep:
- For tall ratios (like 9:16): keep width fixed, calculate height
- - For wide ratios (like 16:9): keep height fixed, calculate width
+ - For wide ratios (like 16:9): keep height fixed, calculate width
- For square (1:1): use the smaller of the current dimensions
-
+
Returns (new_width, new_height).
"""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return current_width, current_height
-
+
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
-
+
if w_ratio == h_ratio:
# Square ratio - use the smaller current dimension to prevent creep
base = min(current_width, current_height)
@@ -64,11 +64,11 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height):
# Wide ratio (like 16:9) - height is the smaller side, keep it fixed
new_height = current_height
new_width = round_to_step(current_height * w_ratio / h_ratio)
-
+
# Clamp to slider bounds
new_width = clamp(new_width, 256, 2048)
new_height = clamp(new_height, 256, 2048)
-
+
return int(new_width), int(new_height)
@@ -76,11 +76,11 @@ def update_height_from_width(width, aspect_ratio):
"""Update height when width changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
-
+
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
new_height = round_to_step(width * h_ratio / w_ratio)
new_height = clamp(new_height, 256, 2048)
-
+
return int(new_height)
@@ -88,18 +88,18 @@ def update_width_from_height(height, aspect_ratio):
"""Update width when height changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
-
+
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
new_width = round_to_step(height * w_ratio / h_ratio)
new_width = clamp(new_width, 256, 2048)
-
+
return int(new_width)
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
"""Swap dimensions and update aspect ratio to match (or set to Custom)."""
new_width, new_height = height, width
-
+
# Try to find a matching aspect ratio for the swapped dimensions
new_ratio = "Custom"
for name, ratios in ASPECT_RATIOS.items():
@@ -111,27 +111,27 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
if abs(expected_height - new_height) < STEP:
new_ratio = name
break
-
+
return new_width, new_height, new_ratio
def create_ui():
# Get effective settings (CLI > yaml > defaults)
settings = get_effective_settings()
-
+
# Update shared state (but don't load the model yet)
if settings['model_name'] != 'None':
shared.image_model_name = settings['model_name']
-
+
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
# TAB 1: GENERATION STUDIO
with gr.TabItem("Generate"):
with gr.Row():
-
+
# === LEFT COLUMN: CONTROLS ===
with gr.Column(scale=4, min_width=350):
-
+
# 1. PROMPT
prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
@@ -170,12 +170,12 @@ def create_ui():
with gr.Column():
batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
-
+
# === RIGHT COLUMN: VIEWPORT ===
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
output_gallery = gr.Gallery(
- label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
+ label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
)
with gr.Row():
used_seed = gr.Markdown(label="Info", interactive=False)
@@ -203,9 +203,9 @@ def create_ui():
image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
-
+
gr.Markdown("## Settings")
-
+
with gr.Row():
with gr.Column():
image_dtype = gr.Dropdown(
@@ -214,14 +214,14 @@ def create_ui():
label='Data Type',
info='bfloat16 recommended for modern GPUs'
)
-
+
image_attn_backend = gr.Dropdown(
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
value=settings['attn_backend'],
label='Attention Backend',
info='SDPA is default. Flash Attention requires compatible GPU.'
)
-
+
with gr.Column():
image_compile = gr.Checkbox(
value=settings['compile_model'],
@@ -234,7 +234,7 @@ def create_ui():
label='CPU Offload',
info='Enable for low VRAM GPUs. Slower but uses less memory.'
)
-
+
with gr.Column():
image_download_path = gr.Textbox(
label="Download model",
@@ -247,7 +247,7 @@ def create_ui():
)
# === WIRING ===
-
+
# Aspect ratio preset changes -> update dimensions
preset_radio.change(
fn=apply_aspect_ratio,
@@ -255,7 +255,7 @@ def create_ui():
outputs=[width_slider, height_slider],
show_progress=False
)
-
+
# Width slider changes -> update height (if not Custom)
width_slider.release(
fn=update_height_from_width,
@@ -263,7 +263,7 @@ def create_ui():
outputs=[height_slider],
show_progress=False
)
-
+
# Height slider changes -> update width (if not Custom)
height_slider.release(
fn=update_width_from_height,
@@ -271,7 +271,7 @@ def create_ui():
outputs=[width_slider],
show_progress=False
)
-
+
# Swap button -> swap dimensions and update aspect ratio
swap_btn.click(
fn=swap_dimensions_and_update_ratio,
@@ -299,7 +299,7 @@ def create_ui():
inputs=inputs,
outputs=outputs
)
-
+
# Model tab events
image_refresh_models.click(
fn=lambda: gr.update(choices=utils.get_available_image_models()),
@@ -307,28 +307,28 @@ def create_ui():
outputs=[image_model_menu],
show_progress=False
)
-
+
image_load_model.click(
fn=load_image_model_wrapper,
inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
outputs=[image_model_status],
show_progress=True
)
-
+
image_unload_model.click(
fn=unload_image_model_wrapper,
inputs=None,
outputs=[image_model_status],
show_progress=False
)
-
+
image_download_btn.click(
fn=download_image_model_wrapper,
inputs=[image_download_path],
outputs=[image_model_status, image_model_menu],
show_progress=True
)
-
+
# History
refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
@@ -336,40 +336,39 @@ def create_ui():
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
"""Generate images with the current model settings."""
-
- # Get current UI values (these are Gradio components, we need their values)
+
model_name = shared.image_model_name
-
+
if model_name == 'None':
return [], "No image model selected. Go to the Model tab and select a model."
-
+
# Auto-load model if not loaded
if shared.image_model is None:
- # Load saved settings for the model
- saved_settings = load_image_model_settings()
-
+ # Get effective settings (CLI > yaml > defaults)
+ settings = get_effective_settings()
+
result = load_image_model(
model_name,
- dtype=saved_settings['dtype'],
- attn_backend=saved_settings['attn_backend'],
- cpu_offload=saved_settings['cpu_offload'],
- compile_model=saved_settings['compile_model']
+ dtype=settings['dtype'],
+ attn_backend=settings['attn_backend'],
+ cpu_offload=settings['cpu_offload'],
+ compile_model=settings['compile_model']
)
-
+
if result is None:
return [], f"Failed to load model `{model_name}`."
-
+
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
-
+
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
-
+
# Sequential loop (easier on VRAM)
for i in range(int(batch_count_seq)):
current_seed = seed + i
generator.manual_seed(int(current_seed))
-
+
# Parallel generation
batch_results = shared.image_model(
prompt=prompt,
@@ -381,12 +380,12 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
num_images_per_prompt=int(batch_size_parallel),
generator=generator,
).images
-
+
all_images.extend(batch_results)
-
+
# Save to disk
save_generated_images(all_images, prompt, seed)
-
+
return all_images, f"Seed: {seed}"
@@ -395,13 +394,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
if model_name == 'None' or not model_name:
yield "No model selected"
return
-
+
try:
yield f"Loading `{model_name}`..."
-
+
# Unload existing model first
unload_image_model()
-
+
# Load the new model
result = load_image_model(
model_name,
@@ -410,14 +409,14 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
cpu_offload=cpu_offload,
compile_model=compile_model
)
-
+
if result is not None:
# Save settings to yaml
save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
yield f"✓ Loaded **{model_name}**"
else:
yield f"✗ Failed to load `{model_name}`"
-
+
except Exception:
exc = traceback.format_exc()
yield f"Error:\n```\n{exc}\n```"
@@ -426,7 +425,7 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
def unload_image_model_wrapper():
"""Unload model wrapper."""
unload_image_model()
-
+
if shared.image_model_name != 'None':
return f"Model: **{shared.image_model_name}** (not loaded)"
else:
@@ -436,36 +435,36 @@ def unload_image_model_wrapper():
def download_image_model_wrapper(model_path):
"""Download a model from Hugging Face."""
from huggingface_hub import snapshot_download
-
+
if not model_path:
yield "No model specified", gr.update()
return
-
+
try:
# Parse model name and branch
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
model_id, branch = model_path, 'main'
-
+
# Output folder name
folder_name = model_id.split('/')[-1]
output_folder = Path(shared.args.image_model_dir) / folder_name
-
+
yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
-
+
snapshot_download(
repo_id=model_id,
revision=branch,
local_dir=output_folder,
local_dir_use_symlinks=False,
)
-
+
# Refresh the model list
new_choices = utils.get_available_image_models()
-
+
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
-
+
except Exception:
exc = traceback.format_exc()
yield f"Error:\n```\n{exc}\n```", gr.update()
From 9e33c6bfb70cb392ba81311504692198f01d8a31 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 15:56:58 -0800
Subject: [PATCH 12/38] Add missing files
---
modules/image_model_settings.py | 108 ++++++++++++++++++++++++++++++++
modules/image_models.py | 81 ++++++++++++++++++++++++
2 files changed, 189 insertions(+)
create mode 100644 modules/image_model_settings.py
create mode 100644 modules/image_models.py
diff --git a/modules/image_model_settings.py b/modules/image_model_settings.py
new file mode 100644
index 00000000..9e644dc8
--- /dev/null
+++ b/modules/image_model_settings.py
@@ -0,0 +1,108 @@
+# modules/image_model_settings.py
+import os
+from pathlib import Path
+
+import yaml
+
+import modules.shared as shared
+from modules.logging_colors import logger
+
+
+DEFAULTS = {
+ 'model_name': 'None',
+ 'dtype': 'bfloat16',
+ 'attn_backend': 'sdpa',
+ 'cpu_offload': False,
+ 'compile_model': False,
+}
+
+
+def get_settings_path():
+ """Get the path to the image model settings file."""
+ return Path(shared.args.image_model_dir) / 'settings.yaml'
+
+
+def load_yaml_settings():
+ """Load raw settings from yaml file."""
+ settings_path = get_settings_path()
+
+ if not settings_path.exists():
+ return {}
+
+ try:
+ with open(settings_path, 'r') as f:
+ saved = yaml.safe_load(f)
+ return saved if saved else {}
+ except Exception as e:
+ logger.warning(f"Failed to load image model settings: {e}")
+ return {}
+
+
+def get_effective_settings():
+ """
+ Get effective settings with precedence:
+ 1. CLI flag (if provided)
+ 2. Saved yaml value (if exists)
+ 3. Hardcoded default
+
+ Returns a dict with all settings.
+ """
+ yaml_settings = load_yaml_settings()
+
+ effective = {}
+
+ # model_name: CLI --image-model > yaml > default
+ if shared.args.image_model:
+ effective['model_name'] = shared.args.image_model
+ else:
+ effective['model_name'] = yaml_settings.get('model_name', DEFAULTS['model_name'])
+
+ # dtype: CLI --image-dtype > yaml > default
+ if shared.args.image_dtype is not None:
+ effective['dtype'] = shared.args.image_dtype
+ else:
+ effective['dtype'] = yaml_settings.get('dtype', DEFAULTS['dtype'])
+
+ # attn_backend: CLI --image-attn-backend > yaml > default
+ if shared.args.image_attn_backend is not None:
+ effective['attn_backend'] = shared.args.image_attn_backend
+ else:
+ effective['attn_backend'] = yaml_settings.get('attn_backend', DEFAULTS['attn_backend'])
+
+ # cpu_offload: CLI --image-cpu-offload > yaml > default
+ # For store_true flags, check if explicitly set (True means it was passed)
+ if shared.args.image_cpu_offload:
+ effective['cpu_offload'] = True
+ else:
+ effective['cpu_offload'] = yaml_settings.get('cpu_offload', DEFAULTS['cpu_offload'])
+
+ # compile_model: CLI --image-compile > yaml > default
+ if shared.args.image_compile:
+ effective['compile_model'] = True
+ else:
+ effective['compile_model'] = yaml_settings.get('compile_model', DEFAULTS['compile_model'])
+
+ return effective
+
+
+def save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model):
+ """Save image model settings to yaml."""
+ settings_path = get_settings_path()
+
+ # Ensure directory exists
+ settings_path.parent.mkdir(parents=True, exist_ok=True)
+
+ settings = {
+ 'model_name': model_name,
+ 'dtype': dtype,
+ 'attn_backend': attn_backend,
+ 'cpu_offload': cpu_offload,
+ 'compile_model': compile_model,
+ }
+
+ try:
+ with open(settings_path, 'w') as f:
+ yaml.dump(settings, f, default_flow_style=False)
+ logger.info(f"Saved image model settings to {settings_path}")
+ except Exception as e:
+ logger.error(f"Failed to save image model settings: {e}")
diff --git a/modules/image_models.py b/modules/image_models.py
new file mode 100644
index 00000000..0d910e42
--- /dev/null
+++ b/modules/image_models.py
@@ -0,0 +1,81 @@
+# modules/image_models.py
+import time
+import torch
+
+import modules.shared as shared
+from modules.logging_colors import logger
+from modules.utils import resolve_model_path
+from modules.torch_utils import get_device
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
+ """
+ Load a diffusers image generation model.
+
+ Args:
+ model_name: Name of the model directory
+ dtype: 'bfloat16' or 'float16'
+ attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
+ cpu_offload: Enable CPU offloading for low VRAM
+ compile_model: Compile the model for faster inference (slow first run)
+ """
+ from diffusers import ZImagePipeline
+
+ logger.info(f"Loading image model \"{model_name}\"")
+ t0 = time.time()
+
+ dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
+ target_dtype = dtype_map.get(dtype, torch.bfloat16)
+
+ model_path = resolve_model_path(model_name, image_model=True)
+
+ try:
+ pipe = ZImagePipeline.from_pretrained(
+ str(model_path),
+ torch_dtype=target_dtype,
+ low_cpu_mem_usage=False,
+ )
+
+ if not cpu_offload:
+ pipe.to(get_device())
+
+ # Set attention backend
+ if attn_backend == 'flash_attention_2':
+ pipe.transformer.set_attention_backend("flash")
+ elif attn_backend == 'flash_attention_3':
+ pipe.transformer.set_attention_backend("_flash_3")
+ # sdpa is the default, no action needed
+
+ if compile_model:
+ logger.info("Compiling model (first run will be slow)...")
+ pipe.transformer.compile()
+
+ if cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ shared.image_model = pipe
+ shared.image_model_name = model_name
+
+ logger.info(f"Loaded image model \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
+ return pipe
+
+ except Exception as e:
+ logger.error(f"Failed to load image model: {str(e)}")
+ return None
+
+
+def unload_image_model():
+ """Unload the current image model and free VRAM."""
+ if shared.image_model is None:
+ return
+
+ del shared.image_model
+ shared.image_model = None
+ shared.image_model_name = 'None'
+
+ # Clear CUDA cache
+ if torch.cuda.is_available():
+
+ torch.cuda.empty_cache()
+
+ logger.info("Image model unloaded.")
From 74eedf605040e3ae2a68fbfab849e14c4a30c552 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 16:13:08 -0800
Subject: [PATCH 13/38] Remove the CFG slider
---
modules/ui_image_generation.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 749ca981..8c5212fe 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -164,7 +164,6 @@ def create_ui():
with gr.Row():
with gr.Column():
steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps")
- cfg_slider = gr.Slider(value=0.0, label="Guidance", interactive=False, info="Locked")
seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 = Random")
with gr.Column():
From 30d1f502aa1383c8991fad048008c5a303744378 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 16:37:03 -0800
Subject: [PATCH 14/38] More informative download message
---
modules/ui_image_generation.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 8c5212fe..7c6a55a6 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -446,11 +446,11 @@ def download_image_model_wrapper(model_path):
else:
model_id, branch = model_path, 'main'
- # Output folder name
- folder_name = model_id.split('/')[-1]
+ # Output folder name (username_model format)
+ folder_name = model_id.replace('/', '_')
output_folder = Path(shared.args.image_model_dir) / folder_name
- yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
+ yield f"Downloading `{model_id}` (branch: {branch}) to `{output_folder}`...", gr.update()
snapshot_download(
repo_id=model_id,
From 822e74ac970a0a1f07731ab3bc7e666a016fcd1f Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 18:15:15 -0800
Subject: [PATCH 15/38] Lint
---
css/main.css | 2 +-
modules/image_model_settings.py | 3 ---
modules/image_models.py | 12 +++++-------
modules/ui_image_generation.py | 7 ++++---
4 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/css/main.css b/css/main.css
index 4e3a0658..565aaf49 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1678,6 +1678,6 @@ button:focus {
button#swap-height-width {
position: absolute;
top: -50px;
- right: 0px;
+ right: 0;
border: 0;
}
diff --git a/modules/image_model_settings.py b/modules/image_model_settings.py
index 9e644dc8..edb6bf20 100644
--- a/modules/image_model_settings.py
+++ b/modules/image_model_settings.py
@@ -1,5 +1,3 @@
-# modules/image_model_settings.py
-import os
from pathlib import Path
import yaml
@@ -7,7 +5,6 @@ import yaml
import modules.shared as shared
from modules.logging_colors import logger
-
DEFAULTS = {
'model_name': 'None',
'dtype': 'bfloat16',
diff --git a/modules/image_models.py b/modules/image_models.py
index 0d910e42..21612f61 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -1,11 +1,11 @@
-# modules/image_models.py
import time
+
import torch
import modules.shared as shared
from modules.logging_colors import logger
-from modules.utils import resolve_model_path
from modules.torch_utils import get_device
+from modules.utils import resolve_model_path
def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
@@ -56,7 +56,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
shared.image_model = pipe
shared.image_model_name = model_name
- logger.info(f"Loaded image model \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
+ logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
return pipe
except Exception as e:
@@ -73,9 +73,7 @@ def unload_image_model():
shared.image_model = None
shared.image_model_name = 'None'
- # Clear CUDA cache
- if torch.cuda.is_available():
-
- torch.cuda.empty_cache()
+ from modules.torch_utils import clear_torch_cache
+ clear_torch_cache()
logger.info("Image model unloaded.")
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 7c6a55a6..038e96c8 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,4 +1,3 @@
-# modules/ui_image_generation.py
import os
import traceback
from datetime import datetime
@@ -9,9 +8,11 @@ import numpy as np
import torch
from modules import shared, utils
+from modules.image_model_settings import (
+ get_effective_settings,
+ save_image_model_settings
+)
from modules.image_models import load_image_model, unload_image_model
-from modules.image_model_settings import get_effective_settings, save_image_model_settings
-
# Aspect ratio definitions: name -> (width_ratio, height_ratio)
ASPECT_RATIOS = {
From 742db85de082b11779c5d51ca2aab87c1210c804 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 18:23:26 -0800
Subject: [PATCH 16/38] Hardcode 8-bit quantization for now
---
modules/image_models.py | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/modules/image_models.py b/modules/image_models.py
index 21612f61..6a6c6547 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -19,7 +19,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
"""
- from diffusers import ZImagePipeline
+ from diffusers import PipelineQuantizationConfig, ZImagePipeline
logger.info(f"Loading image model \"{model_name}\"")
t0 = time.time()
@@ -30,10 +30,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
model_path = resolve_model_path(model_name, image_model=True)
try:
+ # Define quantization config for 8-bit
+ pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ )
+
pipe = ZImagePipeline.from_pretrained(
str(model_path),
+ quantization_config=pipeline_quant_config,
torch_dtype=target_dtype,
- low_cpu_mem_usage=False,
+ low_cpu_mem_usage=True,
)
if not cpu_offload:
From cecb172d2c7445f06dd90ee86aa6e9c0b437e1fe Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 27 Nov 2025 18:29:32 -0800
Subject: [PATCH 17/38] Add the code for 4-bit quantization
---
modules/image_models.py | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/modules/image_models.py b/modules/image_models.py
index 6a6c6547..9e2075fd 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -36,6 +36,17 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
quant_kwargs={"load_in_8bit": True},
)
+ # Define quantization config for 4-bit
+ # pipeline_quant_config = PipelineQuantizationConfig(
+ # quant_backend="bitsandbytes_4bit",
+ # quant_kwargs={
+ # "load_in_4bit": True,
+ # "bnb_4bit_quant_type": "nf4", # Or "fp4" for floating point
+ # "bnb_4bit_compute_dtype": torch.bfloat16, # For faster computation
+ # "bnb_4bit_use_double_quant": True, # Nested quantization for extra savings
+ # },
+ # )
+
pipe = ZImagePipeline.from_pretrained(
str(model_path),
quantization_config=pipeline_quant_config,
From b42192c2b7e6fb541b8d6e77c478653ba4a8c1e4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:42:03 -0800
Subject: [PATCH 18/38] Implement settings autosaving
---
modules/image_model_settings.py | 105 ---------
modules/shared.py | 30 +++
modules/ui.py | 34 ++-
modules/ui_image_generation.py | 382 +++++++++++++-------------------
modules/ui_model_menu.py | 63 ------
server.py | 4 +
6 files changed, 217 insertions(+), 401 deletions(-)
delete mode 100644 modules/image_model_settings.py
diff --git a/modules/image_model_settings.py b/modules/image_model_settings.py
deleted file mode 100644
index edb6bf20..00000000
--- a/modules/image_model_settings.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from pathlib import Path
-
-import yaml
-
-import modules.shared as shared
-from modules.logging_colors import logger
-
-DEFAULTS = {
- 'model_name': 'None',
- 'dtype': 'bfloat16',
- 'attn_backend': 'sdpa',
- 'cpu_offload': False,
- 'compile_model': False,
-}
-
-
-def get_settings_path():
- """Get the path to the image model settings file."""
- return Path(shared.args.image_model_dir) / 'settings.yaml'
-
-
-def load_yaml_settings():
- """Load raw settings from yaml file."""
- settings_path = get_settings_path()
-
- if not settings_path.exists():
- return {}
-
- try:
- with open(settings_path, 'r') as f:
- saved = yaml.safe_load(f)
- return saved if saved else {}
- except Exception as e:
- logger.warning(f"Failed to load image model settings: {e}")
- return {}
-
-
-def get_effective_settings():
- """
- Get effective settings with precedence:
- 1. CLI flag (if provided)
- 2. Saved yaml value (if exists)
- 3. Hardcoded default
-
- Returns a dict with all settings.
- """
- yaml_settings = load_yaml_settings()
-
- effective = {}
-
- # model_name: CLI --image-model > yaml > default
- if shared.args.image_model:
- effective['model_name'] = shared.args.image_model
- else:
- effective['model_name'] = yaml_settings.get('model_name', DEFAULTS['model_name'])
-
- # dtype: CLI --image-dtype > yaml > default
- if shared.args.image_dtype is not None:
- effective['dtype'] = shared.args.image_dtype
- else:
- effective['dtype'] = yaml_settings.get('dtype', DEFAULTS['dtype'])
-
- # attn_backend: CLI --image-attn-backend > yaml > default
- if shared.args.image_attn_backend is not None:
- effective['attn_backend'] = shared.args.image_attn_backend
- else:
- effective['attn_backend'] = yaml_settings.get('attn_backend', DEFAULTS['attn_backend'])
-
- # cpu_offload: CLI --image-cpu-offload > yaml > default
- # For store_true flags, check if explicitly set (True means it was passed)
- if shared.args.image_cpu_offload:
- effective['cpu_offload'] = True
- else:
- effective['cpu_offload'] = yaml_settings.get('cpu_offload', DEFAULTS['cpu_offload'])
-
- # compile_model: CLI --image-compile > yaml > default
- if shared.args.image_compile:
- effective['compile_model'] = True
- else:
- effective['compile_model'] = yaml_settings.get('compile_model', DEFAULTS['compile_model'])
-
- return effective
-
-
-def save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model):
- """Save image model settings to yaml."""
- settings_path = get_settings_path()
-
- # Ensure directory exists
- settings_path.parent.mkdir(parents=True, exist_ok=True)
-
- settings = {
- 'model_name': model_name,
- 'dtype': dtype,
- 'attn_backend': attn_backend,
- 'cpu_offload': cpu_offload,
- 'compile_model': compile_model,
- }
-
- try:
- with open(settings_path, 'w') as f:
- yaml.dump(settings, f, default_flow_style=False)
- logger.info(f"Saved image model settings to {settings_path}")
- except Exception as e:
- logger.error(f"Failed to save image model settings: {e}")
diff --git a/modules/shared.py b/modules/shared.py
index e54eca8f..9a062e91 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -303,6 +303,22 @@ settings = {
# Extensions
'default_extensions': [],
+
+ # Image generation settings
+ 'image_prompt': '',
+ 'image_neg_prompt': '',
+ 'image_width': 1024,
+ 'image_height': 1024,
+ 'image_aspect_ratio': '1:1 Square',
+ 'image_steps': 9,
+ 'image_seed': -1,
+ 'image_batch_size': 1,
+ 'image_batch_count': 1,
+ 'image_model_menu': 'None',
+ 'image_dtype': 'bfloat16',
+ 'image_attn_backend': 'sdpa',
+ 'image_compile': False,
+ 'image_cpu_offload': False,
}
default_settings = copy.deepcopy(settings)
@@ -327,6 +343,20 @@ def do_cmd_flags_warnings():
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
+def apply_image_model_cli_overrides():
+ """Apply CLI flags for image model settings, overriding saved settings."""
+ if args.image_model:
+ settings['image_model_menu'] = args.image_model
+ if args.image_dtype is not None:
+ settings['image_dtype'] = args.image_dtype
+ if args.image_attn_backend is not None:
+ settings['image_attn_backend'] = args.image_attn_backend
+ if args.image_cpu_offload:
+ settings['image_cpu_offload'] = True
+ if args.image_compile:
+ settings['image_compile'] = True
+
+
def fix_loader_name(name):
if not name:
return name
diff --git a/modules/ui.py b/modules/ui.py
index f99e8b6a..3aba20b4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -280,6 +280,24 @@ def list_interface_input_elements():
'include_past_attachments',
]
+ # Image generation elements
+ elements += [
+ 'image_prompt',
+ 'image_neg_prompt',
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+ 'image_model_menu',
+ 'image_dtype',
+ 'image_attn_backend',
+ 'image_compile',
+ 'image_cpu_offload',
+ ]
+
return elements
@@ -509,7 +527,21 @@ def setup_auto_save():
'theme_state',
'show_two_notebook_columns',
'paste_to_attachment',
- 'include_past_attachments'
+ 'include_past_attachments',
+
+ # Image generation tab (ui_image_generation.py)
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+ 'image_model_menu',
+ 'image_dtype',
+ 'image_attn_backend',
+ 'image_compile',
+ 'image_cpu_offload',
]
for element_name in change_elements:
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 038e96c8..fe0c8120 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -7,14 +7,10 @@ import gradio as gr
import numpy as np
import torch
-from modules import shared, utils
-from modules.image_model_settings import (
- get_effective_settings,
- save_image_model_settings
-)
+from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
+from modules.utils import gradio
-# Aspect ratio definitions: name -> (width_ratio, height_ratio)
ASPECT_RATIOS = {
"1:1 Square": (1, 1),
"16:9 Cinema": (16, 9),
@@ -23,50 +19,34 @@ ASPECT_RATIOS = {
"Custom": None,
}
-STEP = 32 # Slider step for rounding
+STEP = 32
def round_to_step(value, step=STEP):
- """Round a value to the nearest step."""
return round(value / step) * step
def clamp(value, min_val, max_val):
- """Clamp value between min and max."""
return max(min_val, min(max_val, value))
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
- """
- Apply an aspect ratio preset.
-
- Logic to prevent dimension creep:
- - For tall ratios (like 9:16): keep width fixed, calculate height
- - For wide ratios (like 16:9): keep height fixed, calculate width
- - For square (1:1): use the smaller of the current dimensions
-
- Returns (new_width, new_height).
- """
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return current_width, current_height
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
if w_ratio == h_ratio:
- # Square ratio - use the smaller current dimension to prevent creep
base = min(current_width, current_height)
new_width = base
new_height = base
elif w_ratio < h_ratio:
- # Tall ratio (like 9:16) - width is the smaller side, keep it fixed
new_width = current_width
new_height = round_to_step(current_width * h_ratio / w_ratio)
else:
- # Wide ratio (like 16:9) - height is the smaller side, keep it fixed
new_height = current_height
new_width = round_to_step(current_height * w_ratio / h_ratio)
- # Clamp to slider bounds
new_width = clamp(new_width, 256, 2048)
new_height = clamp(new_height, 256, 2048)
@@ -74,7 +54,6 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height):
def update_height_from_width(width, aspect_ratio):
- """Update height when width changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
@@ -86,7 +65,6 @@ def update_height_from_width(width, aspect_ratio):
def update_width_from_height(height, aspect_ratio):
- """Update width when height changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
@@ -98,16 +76,13 @@ def update_width_from_height(height, aspect_ratio):
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
- """Swap dimensions and update aspect ratio to match (or set to Custom)."""
new_width, new_height = height, width
- # Try to find a matching aspect ratio for the swapped dimensions
new_ratio = "Custom"
for name, ratios in ASPECT_RATIOS.items():
if ratios is None:
continue
w_r, h_r = ratios
- # Check if the swapped dimensions match this ratio (within tolerance)
expected_height = new_width * h_r / w_r
if abs(expected_height - new_height) < STEP:
new_ratio = name
@@ -117,291 +92,257 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
def create_ui():
- # Get effective settings (CLI > yaml > defaults)
- settings = get_effective_settings()
-
- # Update shared state (but don't load the model yet)
- if settings['model_name'] != 'None':
- shared.image_model_name = settings['model_name']
+ if shared.settings['image_model_menu'] != 'None':
+ shared.image_model_name = shared.settings['image_model_menu']
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
- # TAB 1: GENERATION STUDIO
+ # TAB 1: GENERATE
with gr.TabItem("Generate"):
with gr.Row():
-
- # === LEFT COLUMN: CONTROLS ===
with gr.Column(scale=4, min_width=350):
+ shared.gradio['image_prompt'] = gr.Textbox(
+ label="Prompt",
+ placeholder="Describe your imagination...",
+ lines=3,
+ autofocus=True,
+ value=shared.settings['image_prompt']
+ )
+ shared.gradio['image_neg_prompt'] = gr.Textbox(
+ label="Negative Prompt",
+ placeholder="Low quality...",
+ lines=3,
+ value=shared.settings['image_neg_prompt']
+ )
- # 1. PROMPT
- prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
- neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
-
- # 2. GENERATE BUTTON
- generate_btn = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
+ shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("
")
- # 3. DIMENSIONS
gr.Markdown("### 📐 Dimensions")
with gr.Row():
with gr.Column():
- width_slider = gr.Slider(256, 2048, value=1024, step=32, label="Width")
-
+ shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
with gr.Column():
- height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height")
-
- swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
+ shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
+ shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
with gr.Row():
- preset_radio = gr.Radio(
+ shared.gradio['image_aspect_ratio'] = gr.Radio(
choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
- value="1:1 Square",
+ value=shared.settings['image_aspect_ratio'],
label="Aspect Ratio",
interactive=True
)
- # 4. SETTINGS & BATCHING
gr.Markdown("### ⚙️ Config")
with gr.Row():
with gr.Column():
- steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps")
- seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 = Random")
-
+ shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
+ shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
- batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
- batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
+ shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
+ shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
- # === RIGHT COLUMN: VIEWPORT ===
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
- output_gallery = gr.Gallery(
- label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
- )
+ shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True)
with gr.Row():
- used_seed = gr.Markdown(label="Info", interactive=False)
+ shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)
- # TAB 2: HISTORY VIEWER
+ # TAB 2: GALLERY
with gr.TabItem("Gallery"):
with gr.Row():
- refresh_btn = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
+ shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
+ shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True)
- history_gallery = gr.Gallery(
- label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True
- )
-
- # TAB 3: MODEL SETTINGS
+ # TAB 3: MODEL
with gr.TabItem("Model"):
with gr.Row():
with gr.Column():
with gr.Row():
- image_model_menu = gr.Dropdown(
+ shared.gradio['image_model_menu'] = gr.Dropdown(
choices=utils.get_available_image_models(),
- value=settings['model_name'],
+ value=shared.settings['image_model_menu'],
label='Model',
elem_classes='slim-dropdown'
)
- image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
- image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
- image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
+ shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
+ shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
+ shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
gr.Markdown("## Settings")
-
with gr.Row():
with gr.Column():
- image_dtype = gr.Dropdown(
+ shared.gradio['image_dtype'] = gr.Dropdown(
choices=['bfloat16', 'float16'],
- value=settings['dtype'],
+ value=shared.settings['image_dtype'],
label='Data Type',
info='bfloat16 recommended for modern GPUs'
)
-
- image_attn_backend = gr.Dropdown(
+ shared.gradio['image_attn_backend'] = gr.Dropdown(
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
- value=settings['attn_backend'],
+ value=shared.settings['image_attn_backend'],
label='Attention Backend',
info='SDPA is default. Flash Attention requires compatible GPU.'
)
-
with gr.Column():
- image_compile = gr.Checkbox(
- value=settings['compile_model'],
+ shared.gradio['image_compile'] = gr.Checkbox(
+ value=shared.settings['image_compile'],
label='Compile Model',
info='Faster inference after first run. First run will be slow.'
)
-
- image_cpu_offload = gr.Checkbox(
- value=settings['cpu_offload'],
+ shared.gradio['image_cpu_offload'] = gr.Checkbox(
+ value=shared.settings['image_cpu_offload'],
label='CPU Offload',
info='Enable for low VRAM GPUs. Slower but uses less memory.'
)
with gr.Column():
- image_download_path = gr.Textbox(
+ shared.gradio['image_download_path'] = gr.Textbox(
label="Download model",
placeholder="Tongyi-MAI/Z-Image-Turbo",
- info="Enter the HuggingFace model path like Tongyi-MAI/Z-Image-Turbo. Use : for branch, e.g. Tongyi-MAI/Z-Image-Turbo:main"
+ info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
)
- image_download_btn = gr.Button("Download", variant='primary')
- image_model_status = gr.Markdown(
- value=f"Model: **{settings['model_name']}** (not loaded)" if settings['model_name'] != 'None' else "No model selected"
+ shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
+ shared.gradio['image_model_status'] = gr.Markdown(
+ value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
)
- # === WIRING ===
- # Aspect ratio preset changes -> update dimensions
- preset_radio.change(
- fn=apply_aspect_ratio,
- inputs=[preset_radio, width_slider, height_slider],
- outputs=[width_slider, height_slider],
- show_progress=False
- )
+def create_event_handlers():
+ # Dimension controls
+ shared.gradio['image_aspect_ratio'].change(
+ apply_aspect_ratio,
+ gradio('image_aspect_ratio', 'image_width', 'image_height'),
+ gradio('image_width', 'image_height'),
+ show_progress=False
+ )
- # Width slider changes -> update height (if not Custom)
- width_slider.release(
- fn=update_height_from_width,
- inputs=[width_slider, preset_radio],
- outputs=[height_slider],
- show_progress=False
- )
+ shared.gradio['image_width'].release(
+ update_height_from_width,
+ gradio('image_width', 'image_aspect_ratio'),
+ gradio('image_height'),
+ show_progress=False
+ )
- # Height slider changes -> update width (if not Custom)
- height_slider.release(
- fn=update_width_from_height,
- inputs=[height_slider, preset_radio],
- outputs=[width_slider],
- show_progress=False
- )
+ shared.gradio['image_height'].release(
+ update_width_from_height,
+ gradio('image_height', 'image_aspect_ratio'),
+ gradio('image_width'),
+ show_progress=False
+ )
- # Swap button -> swap dimensions and update aspect ratio
- swap_btn.click(
- fn=swap_dimensions_and_update_ratio,
- inputs=[width_slider, height_slider, preset_radio],
- outputs=[width_slider, height_slider, preset_radio],
- show_progress=False
- )
+ shared.gradio['image_swap_btn'].click(
+ swap_dimensions_and_update_ratio,
+ gradio('image_width', 'image_height', 'image_aspect_ratio'),
+ gradio('image_width', 'image_height', 'image_aspect_ratio'),
+ show_progress=False
+ )
- # Generation
- inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
- outputs = [output_gallery, used_seed]
+ # Generation
+ shared.gradio['image_generate_btn'].click(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- generate_btn.click(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
- prompt.submit(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
- neg_prompt.submit(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
+ shared.gradio['image_prompt'].submit(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- # Model tab events
- image_refresh_models.click(
- fn=lambda: gr.update(choices=utils.get_available_image_models()),
- inputs=None,
- outputs=[image_model_menu],
- show_progress=False
- )
+ shared.gradio['image_neg_prompt'].submit(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- image_load_model.click(
- fn=load_image_model_wrapper,
- inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
- outputs=[image_model_status],
- show_progress=True
- )
+ # Model management
+ shared.gradio['image_refresh_models'].click(
+ lambda: gr.update(choices=utils.get_available_image_models()),
+ None,
+ gradio('image_model_menu'),
+ show_progress=False
+ )
- image_unload_model.click(
- fn=unload_image_model_wrapper,
- inputs=None,
- outputs=[image_model_status],
- show_progress=False
- )
+ shared.gradio['image_load_model'].click(
+ load_image_model_wrapper,
+ gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+ gradio('image_model_status'),
+ show_progress=True
+ )
- image_download_btn.click(
- fn=download_image_model_wrapper,
- inputs=[image_download_path],
- outputs=[image_model_status, image_model_menu],
- show_progress=True
- )
+ shared.gradio['image_unload_model'].click(
+ unload_image_model_wrapper,
+ None,
+ gradio('image_model_status'),
+ show_progress=False
+ )
- # History
- refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
+ shared.gradio['image_download_btn'].click(
+ download_image_model_wrapper,
+ gradio('image_download_path'),
+ gradio('image_model_status', 'image_model_menu'),
+ show_progress=True
+ )
+
+ # History
+ shared.gradio['image_refresh_history'].click(
+ get_history_images,
+ None,
+ gradio('image_history_gallery'),
+ show_progress=False
+ )
-def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
- model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
- """Generate images with the current model settings."""
+def generate(state):
+ model_name = state['image_model_menu']
- model_name = shared.image_model_name
-
- if model_name == 'None':
+ if not model_name or model_name == 'None':
return [], "No image model selected. Go to the Model tab and select a model."
- # Auto-load model if not loaded
if shared.image_model is None:
- # Get effective settings (CLI > yaml > defaults)
- settings = get_effective_settings()
-
result = load_image_model(
model_name,
- dtype=settings['dtype'],
- attn_backend=settings['attn_backend'],
- cpu_offload=settings['cpu_offload'],
- compile_model=settings['compile_model']
+ dtype=state['image_dtype'],
+ attn_backend=state['image_attn_backend'],
+ cpu_offload=state['image_cpu_offload'],
+ compile_model=state['image_compile']
)
-
if result is None:
return [], f"Failed to load model `{model_name}`."
+ shared.image_model_name = model_name
+
+ seed = state['image_seed']
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
- # Sequential loop (easier on VRAM)
- for i in range(int(batch_count_seq)):
- current_seed = seed + i
- generator.manual_seed(int(current_seed))
-
- # Parallel generation
+ for i in range(int(state['image_batch_count'])):
+ generator.manual_seed(int(seed + i))
batch_results = shared.image_model(
- prompt=prompt,
- negative_prompt=neg_prompt,
- height=int(height),
- width=int(width),
- num_inference_steps=int(steps),
+ prompt=state['image_prompt'],
+ negative_prompt=state['image_neg_prompt'],
+ height=int(state['image_height']),
+ width=int(state['image_width']),
+ num_inference_steps=int(state['image_steps']),
guidance_scale=0.0,
- num_images_per_prompt=int(batch_size_parallel),
+ num_images_per_prompt=int(state['image_batch_size']),
generator=generator,
).images
-
all_images.extend(batch_results)
- # Save to disk
- save_generated_images(all_images, prompt, seed)
-
+ save_generated_images(all_images, state['image_prompt'], seed)
return all_images, f"Seed: {seed}"
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
- """Load model and save settings."""
- if model_name == 'None' or not model_name:
+ if not model_name or model_name == 'None':
yield "No model selected"
return
try:
yield f"Loading `{model_name}`..."
-
- # Unload existing model first
unload_image_model()
- # Load the new model
result = load_image_model(
model_name,
dtype=dtype,
@@ -411,29 +352,22 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
)
if result is not None:
- # Save settings to yaml
- save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
+ shared.image_model_name = model_name
yield f"✓ Loaded **{model_name}**"
else:
yield f"✗ Failed to load `{model_name}`"
-
except Exception:
- exc = traceback.format_exc()
- yield f"Error:\n```\n{exc}\n```"
+ yield f"Error:\n```\n{traceback.format_exc()}\n```"
def unload_image_model_wrapper():
- """Unload model wrapper."""
unload_image_model()
-
if shared.image_model_name != 'None':
return f"Model: **{shared.image_model_name}** (not loaded)"
- else:
- return "No model loaded"
+ return "No model loaded"
def download_image_model_wrapper(model_path):
- """Download a model from Hugging Face."""
from huggingface_hub import snapshot_download
if not model_path:
@@ -441,17 +375,15 @@ def download_image_model_wrapper(model_path):
return
try:
- # Parse model name and branch
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
model_id, branch = model_path, 'main'
- # Output folder name (username_model format)
folder_name = model_id.replace('/', '_')
output_folder = Path(shared.args.image_model_dir) / folder_name
- yield f"Downloading `{model_id}` (branch: {branch}) to `{output_folder}`...", gr.update()
+ yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
snapshot_download(
repo_id=model_id,
@@ -460,48 +392,34 @@ def download_image_model_wrapper(model_path):
local_dir_use_symlinks=False,
)
- # Refresh the model list
new_choices = utils.get_available_image_models()
-
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
-
except Exception:
- exc = traceback.format_exc()
- yield f"Error:\n```\n{exc}\n```", gr.update()
+ yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
- """Save generated images to disk."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
- saved_paths = []
-
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{seed}_{idx}.png"
- full_path = os.path.join(folder_path, filename)
-
- img.save(full_path)
- saved_paths.append(full_path)
-
- return saved_paths
+ img.save(os.path.join(folder_path, filename))
def get_history_images():
- """Scan the outputs folder and return all images, newest first."""
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
image_files = []
- for root, dirs, files in os.walk(output_dir):
+ for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
- mtime = os.path.getmtime(full_path)
- image_files.append((full_path, mtime))
+ image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index a5e0f640..86adc229 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -434,66 +434,3 @@ def format_file_size(size_bytes):
return f"{s:.2f} {size_names[i]}"
else:
return f"{s:.1f} {size_names[i]}"
-
-
-def load_image_model_wrapper(selected_model):
- """Wrapper for loading image models with status updates."""
- from modules.image_models import load_image_model, unload_image_model
-
- if selected_model == 'None' or not selected_model:
- yield "No model selected"
- return
-
- try:
- yield f"Loading `{selected_model}`..."
- unload_image_model()
- result = load_image_model(selected_model)
-
- if result is not None:
- yield f"Successfully loaded `{selected_model}`."
- else:
- yield f"Failed to load `{selected_model}`."
- except Exception:
- exc = traceback.format_exc()
- yield exc.replace('\n', '\n\n')
-
-
-def handle_unload_image_model_click():
- """Handler for the image model unload button."""
- from modules.image_models import unload_image_model
- unload_image_model()
- return "Image model unloaded"
-
-
-def download_image_model_wrapper(custom_model):
- """Download an image model from Hugging Face."""
- from huggingface_hub import snapshot_download
-
- if not custom_model:
- yield "No model specified"
- return
-
- try:
- # Parse model name and branch
- if ':' in custom_model:
- model_name, branch = custom_model.rsplit(':', 1)
- else:
- model_name, branch = custom_model, 'main'
-
- # Output folder
- output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
-
- yield f"Downloading `{model_name}` (branch: {branch})..."
-
- snapshot_download(
- repo_id=model_name,
- revision=branch,
- local_dir=output_folder,
- local_dir_use_symlinks=False,
- )
-
- yield f"Model successfully saved to `{output_folder}/`."
-
- except Exception:
- exc = traceback.format_exc()
- yield exc.replace('\n', '\n\n')
diff --git a/server.py b/server.py
index 87bbdc4a..5a75e887 100644
--- a/server.py
+++ b/server.py
@@ -172,6 +172,7 @@ def create_interface():
ui_chat.create_event_handlers()
ui_default.create_event_handlers()
ui_notebook.create_event_handlers()
+ ui_image_generation.create_event_handlers()
# Other events
ui_file_saving.create_event_handlers()
@@ -258,6 +259,9 @@ if __name__ == "__main__":
if new_settings:
shared.settings.update(new_settings)
+ # Apply CLI overrides for image model settings (CLI flags take precedence over saved settings)
+ shared.apply_image_model_cli_overrides()
+
# Fallback settings for models
shared.model_config['.*'] = get_fallback_settings()
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
From 5b385dc546467e08543eafdc81c8ae955a6c5ba5 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:48:55 -0800
Subject: [PATCH 19/38] Make the image galleries taller
---
css/main.css | 8 ++++++++
modules/ui_image_generation.py | 4 ++--
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/css/main.css b/css/main.css
index 565aaf49..317a31f1 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1681,3 +1681,11 @@ button#swap-height-width {
right: 0;
border: 0;
}
+
+#image-output-gallery {
+ height: calc(100vh - 105px);
+}
+
+#image-history-gallery {
+ height: calc(100vh - 139px);
+}
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index fe0c8120..6f251b0e 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -145,7 +145,7 @@ def create_ui():
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
- shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True)
+ shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
with gr.Row():
shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)
@@ -153,7 +153,7 @@ def create_ui():
with gr.TabItem("Gallery"):
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
- shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True)
+ shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
# TAB 3: MODEL
with gr.TabItem("Model"):
From e301dd231e4c33d63b1096d7af3816dab4d29bae Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:49:22 -0800
Subject: [PATCH 20/38] Remove some emojis
---
modules/ui_image_generation.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 6f251b0e..b71c9f46 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -118,7 +118,7 @@ def create_ui():
shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("
")
- gr.Markdown("### 📐 Dimensions")
+ gr.Markdown("### Dimensions")
with gr.Row():
with gr.Column():
shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
@@ -134,7 +134,7 @@ def create_ui():
interactive=True
)
- gr.Markdown("### ⚙️ Config")
+ gr.Markdown("### Config")
with gr.Row():
with gr.Column():
shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
From 9b07a8333075a48ae5b4a52e224f4ce68a127c7d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:51:12 -0800
Subject: [PATCH 21/38] Populate the history gallery by default
---
modules/ui_image_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index b71c9f46..435a0300 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -153,7 +153,7 @@ def create_ui():
with gr.TabItem("Gallery"):
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
- shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
+ shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
# TAB 3: MODEL
with gr.TabItem("Model"):
From 366fe353f0379b60c2ce2b8d8237b18f9f081d8b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:53:17 -0800
Subject: [PATCH 22/38] Revert CSS changes
---
css/main.css | 8 --------
1 file changed, 8 deletions(-)
diff --git a/css/main.css b/css/main.css
index 317a31f1..565aaf49 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1681,11 +1681,3 @@ button#swap-height-width {
right: 0;
border: 0;
}
-
-#image-output-gallery {
- height: calc(100vh - 105px);
-}
-
-#image-history-gallery {
- height: calc(100vh - 139px);
-}
From 990f0e2468660c2bcc9b95c8a366fcfeffa5f38a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:29:22 -0800
Subject: [PATCH 23/38] Revert "Revert CSS changes"
This reverts commit 366fe353f0379b60c2ce2b8d8237b18f9f081d8b.
---
css/main.css | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/css/main.css b/css/main.css
index 565aaf49..317a31f1 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1681,3 +1681,11 @@ button#swap-height-width {
right: 0;
border: 0;
}
+
+#image-output-gallery {
+ height: calc(100vh - 105px);
+}
+
+#image-history-gallery {
+ height: calc(100vh - 139px);
+}
From 75796f5a5828eef10c35d4a67a6984cb07279e89 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:44:18 -0800
Subject: [PATCH 24/38] Set gallery heights
---
css/main.css | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/css/main.css b/css/main.css
index 317a31f1..95fbdcd7 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1682,10 +1682,12 @@ button#swap-height-width {
border: 0;
}
-#image-output-gallery {
- height: calc(100vh - 105px);
+#image-output-gallery, #image-output-gallery > :nth-child(2) {
+ height: calc(100vh - 128px);
+ max-height: calc(100vh - 128px);
}
-#image-history-gallery {
+#image-history-gallery, #image-history-gallery > :nth-child(2) {
height: calc(100vh - 139px);
+ max-height: calc(100vh - 139px);
}
From b4738beaf8eab9766f111fe92134ea08f65f75b4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:59:10 -0800
Subject: [PATCH 25/38] Remove the seed UI element
---
modules/ui_image_generation.py | 22 ++++++++++++++--------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 435a0300..d9e79973 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,4 +1,5 @@
import os
+import time
import traceback
from datetime import datetime
from pathlib import Path
@@ -9,6 +10,7 @@ import torch
from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
+from modules.logging_colors import logger
from modules.utils import gradio
ASPECT_RATIOS = {
@@ -146,8 +148,6 @@ def create_ui():
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
- with gr.Row():
- shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)
# TAB 2: GALLERY
with gr.TabItem("Gallery"):
@@ -242,15 +242,15 @@ def create_event_handlers():
# Generation
shared.gradio['image_generate_btn'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+ generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+ generate, gradio('interface_state'), gradio('image_output_gallery'))
shared.gradio['image_neg_prompt'].submit(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
- generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+ generate, gradio('interface_state'), gradio('image_output_gallery'))
# Model management
shared.gradio['image_refresh_models'].click(
@@ -294,7 +294,8 @@ def generate(state):
model_name = state['image_model_menu']
if not model_name or model_name == 'None':
- return [], "No image model selected. Go to the Model tab and select a model."
+ logger.error("No image model selected. Go to the Model tab and select a model.")
+ return []
if shared.image_model is None:
result = load_image_model(
@@ -305,7 +306,8 @@ def generate(state):
compile_model=state['image_compile']
)
if result is None:
- return [], f"Failed to load model `{model_name}`."
+ logger.error(f"Failed to load model `{model_name}`.")
+ return []
shared.image_model_name = model_name
@@ -316,6 +318,7 @@ def generate(state):
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
+ t0 = time.time()
for i in range(int(state['image_batch_count'])):
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(
@@ -330,8 +333,11 @@ def generate(state):
).images
all_images.extend(batch_results)
+ t1 = time.time()
save_generated_images(all_images, state['image_prompt'], seed)
- return all_images, f"Seed: {seed}"
+
+ logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
+ return all_images
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
From c8e9d7fc37eab8fe1ff1f36267c9824dca4e4b69 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 14:00:41 -0800
Subject: [PATCH 26/38] Fix the gallery height after the previous commit
---
css/main.css | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/css/main.css b/css/main.css
index 95fbdcd7..26687eb4 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1683,8 +1683,8 @@ button#swap-height-width {
}
#image-output-gallery, #image-output-gallery > :nth-child(2) {
- height: calc(100vh - 128px);
- max-height: calc(100vh - 128px);
+ height: calc(100vh - 83px);
+ max-height: calc(100vh - 83px);
}
#image-history-gallery, #image-history-gallery > :nth-child(2) {
From 6a7209a8422c94dde56f4638c233532c2e7ce002 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 15:41:58 -0800
Subject: [PATCH 27/38] Add PNG metadata, add pagination to Gallery tab
---
css/main.css | 50 ++++-
modules/ui_image_generation.py | 350 +++++++++++++++++++++++++++++----
2 files changed, 363 insertions(+), 37 deletions(-)
diff --git a/css/main.css b/css/main.css
index 26687eb4..0bfdca0a 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1688,6 +1688,52 @@ button#swap-height-width {
}
#image-history-gallery, #image-history-gallery > :nth-child(2) {
- height: calc(100vh - 139px);
- max-height: calc(100vh - 139px);
+ height: calc(100vh - 174px);
+ max-height: calc(100vh - 174px);
+}
+
+/* Additional CSS for the paginated image gallery */
+
+/* Page info styling */
+#image-page-info {
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ min-width: 200px;
+ font-size: 0.9em;
+ color: var(--body-text-color-subdued);
+}
+
+/* Settings display panel */
+#image-ai-tab .settings-display-panel {
+ background: var(--background-fill-secondary);
+ padding: 12px;
+ border-radius: 8px;
+ font-size: 0.9em;
+ max-height: 300px;
+ overflow-y: auto;
+ margin-top: 8px;
+}
+
+/* Gallery status message */
+#image-ai-tab .gallery-status {
+ color: var(--color-accent);
+ font-size: 0.85em;
+ margin-top: 4px;
+}
+
+/* Pagination button row alignment */
+#image-ai-tab .pagination-controls {
+ display: flex;
+ align-items: center;
+ gap: 8px;
+ flex-wrap: wrap;
+}
+
+/* Selected image preview container */
+#image-ai-tab .selected-preview-container {
+ border: 1px solid var(--border-color-primary);
+ border-radius: 8px;
+ padding: 8px;
+ background: var(--background-fill-secondary);
}
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index d9e79973..b202f6cc 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,3 +1,4 @@
+import json
import os
import time
import traceback
@@ -7,6 +8,8 @@ from pathlib import Path
import gradio as gr
import numpy as np
import torch
+from PIL import Image
+from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
@@ -22,6 +25,24 @@ ASPECT_RATIOS = {
}
STEP = 32
+IMAGES_PER_PAGE = 64
+
+# Settings keys to save in PNG metadata (Generate tab only)
+METADATA_SETTINGS_KEYS = [
+ 'image_prompt',
+ 'image_neg_prompt',
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+]
+
+# Cache for all image paths
+_image_cache = []
+_cache_timestamp = 0
def round_to_step(value, step=STEP):
@@ -93,6 +114,215 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
return new_width, new_height, new_ratio
+def build_generation_metadata(state, actual_seed):
+ """Build metadata dict from generation settings."""
+ metadata = {}
+ for key in METADATA_SETTINGS_KEYS:
+ if key in state:
+ metadata[key] = state[key]
+
+ # Store the actual seed used (not -1)
+ metadata['image_seed'] = actual_seed
+ metadata['generated_at'] = datetime.now().isoformat()
+ metadata['model'] = shared.image_model_name
+
+ return metadata
+
+
+def save_generated_images(images, state, actual_seed):
+ """Save images with generation metadata embedded in PNG."""
+ date_str = datetime.now().strftime("%Y-%m-%d")
+ folder_path = os.path.join("user_data", "image_outputs", date_str)
+ os.makedirs(folder_path, exist_ok=True)
+
+ metadata = build_generation_metadata(state, actual_seed)
+ metadata_json = json.dumps(metadata, ensure_ascii=False)
+
+ for idx, img in enumerate(images):
+ timestamp = datetime.now().strftime("%H-%M-%S")
+ filename = f"{timestamp}_{actual_seed}_{idx}.png"
+ filepath = os.path.join(folder_path, filename)
+
+ # Create PNG metadata
+ png_info = PngInfo()
+ png_info.add_text("image_gen_settings", metadata_json)
+
+ # Save with metadata
+ img.save(filepath, pnginfo=png_info)
+
+
+def read_image_metadata(image_path):
+ """Read generation metadata from PNG file."""
+ try:
+ with Image.open(image_path) as img:
+ if hasattr(img, 'text') and 'image_gen_settings' in img.text:
+ return json.loads(img.text['image_gen_settings'])
+ except Exception as e:
+ logger.debug(f"Could not read metadata from {image_path}: {e}")
+ return None
+
+
+def format_metadata_for_display(metadata):
+ """Format metadata as readable text."""
+ if not metadata:
+ return "No generation settings found in this image."
+
+ lines = ["**Generation Settings**", ""]
+
+ # Display in a nice order
+ display_order = [
+ ('image_prompt', 'Prompt'),
+ ('image_neg_prompt', 'Negative Prompt'),
+ ('image_width', 'Width'),
+ ('image_height', 'Height'),
+ ('image_aspect_ratio', 'Aspect Ratio'),
+ ('image_steps', 'Steps'),
+ ('image_seed', 'Seed'),
+ ('image_batch_size', 'Batch Size'),
+ ('image_batch_count', 'Batch Count'),
+ ('model', 'Model'),
+ ('generated_at', 'Generated At'),
+ ]
+
+ for key, label in display_order:
+ if key in metadata:
+ value = metadata[key]
+ if key in ['image_prompt', 'image_neg_prompt'] and value:
+ # Truncate long prompts for display
+ if len(str(value)) > 200:
+ value = str(value)[:200] + "..."
+ lines.append(f"**{label}:** {value}")
+
+ return "\n\n".join(lines)
+
+
+def get_all_history_images(force_refresh=False):
+ """Get all history images sorted by modification time (newest first). Uses caching."""
+ global _image_cache, _cache_timestamp
+
+ output_dir = os.path.join("user_data", "image_outputs")
+ if not os.path.exists(output_dir):
+ return []
+
+ # Check if we need to refresh cache
+ current_time = time.time()
+ if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
+ return _image_cache
+
+ image_files = []
+ for root, _, files in os.walk(output_dir):
+ for file in files:
+ if file.endswith((".png", ".jpg", ".jpeg")):
+ full_path = os.path.join(root, file)
+ image_files.append((full_path, os.path.getmtime(full_path)))
+
+ image_files.sort(key=lambda x: x[1], reverse=True)
+ _image_cache = [x[0] for x in image_files]
+ _cache_timestamp = current_time
+
+ return _image_cache
+
+
+def get_paginated_images(page=0, force_refresh=False):
+ """Get images for a specific page."""
+ all_images = get_all_history_images(force_refresh)
+ total_images = len(all_images)
+ total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
+
+ # Clamp page to valid range
+ page = max(0, min(page, total_pages - 1))
+
+ start_idx = page * IMAGES_PER_PAGE
+ end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
+
+ page_images = all_images[start_idx:end_idx]
+
+ return page_images, page, total_pages, total_images
+
+
+def refresh_gallery(current_page=0):
+ """Refresh gallery with current page."""
+ images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
+ page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+ return images, page, page_info
+
+
+def go_to_page(page_num, current_page):
+ """Go to a specific page (1-indexed input)."""
+ try:
+ page = int(page_num) - 1 # Convert to 0-indexed
+ except (ValueError, TypeError):
+ page = current_page
+
+ images, page, total_pages, total_images = get_paginated_images(page)
+ page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+ return images, page, page_info
+
+
+def next_page(current_page):
+ """Go to next page."""
+ images, page, total_pages, total_images = get_paginated_images(current_page + 1)
+ page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+ return images, page, page_info
+
+
+def prev_page(current_page):
+ """Go to previous page."""
+ images, page, total_pages, total_images = get_paginated_images(current_page - 1)
+ page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+ return images, page, page_info
+
+
+def on_gallery_select(evt: gr.SelectData, current_page):
+ """Handle image selection from gallery."""
+ if evt.index is None:
+ return "", "Select an image to view its settings"
+
+ # Get the current page's images to find the actual file path
+ all_images = get_all_history_images()
+ total_images = len(all_images)
+
+ # Calculate the actual index in the full list
+ start_idx = current_page * IMAGES_PER_PAGE
+ actual_idx = start_idx + evt.index
+
+ if actual_idx >= total_images:
+ return "", "Image not found"
+
+ image_path = all_images[actual_idx]
+ metadata = read_image_metadata(image_path)
+ metadata_display = format_metadata_for_display(metadata)
+
+ return image_path, metadata_display
+
+
+def send_to_generate(selected_image_path):
+ """Load settings from selected image and return updates for all Generate tab inputs."""
+ if not selected_image_path or not os.path.exists(selected_image_path):
+ return [gr.update()] * 9 + ["No image selected"]
+
+ metadata = read_image_metadata(selected_image_path)
+ if not metadata:
+ return [gr.update()] * 9 + ["No settings found in this image"]
+
+ # Return updates for each input element in order
+ updates = [
+ gr.update(value=metadata.get('image_prompt', '')),
+ gr.update(value=metadata.get('image_neg_prompt', '')),
+ gr.update(value=metadata.get('image_width', 1024)),
+ gr.update(value=metadata.get('image_height', 1024)),
+ gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
+ gr.update(value=metadata.get('image_steps', 9)),
+ gr.update(value=metadata.get('image_seed', -1)),
+ gr.update(value=metadata.get('image_batch_size', 1)),
+ gr.update(value=metadata.get('image_batch_count', 1)),
+ ]
+
+ status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
+ return updates + [status]
+
+
+
def create_ui():
if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu']
@@ -149,11 +379,40 @@ def create_ui():
with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
- # TAB 2: GALLERY
+ # TAB 2: GALLERY (with pagination)
with gr.TabItem("Gallery"):
with gr.Row():
- shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
- shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
+ with gr.Column(scale=3):
+ # Pagination controls
+ with gr.Row():
+ shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
+ shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
+ shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
+ shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
+ shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
+ shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
+
+ # State for current page and selected image path
+ shared.gradio['image_current_page'] = gr.State(value=0)
+ shared.gradio['image_selected_path'] = gr.State(value="")
+
+ # Paginated gallery using gr.Gallery
+ shared.gradio['image_history_gallery'] = gr.Gallery(
+ value=lambda: get_paginated_images(0)[0],
+ label="Image History",
+ show_label=False,
+ columns=6,
+ object_fit="cover",
+ height="auto",
+ allow_preview=True,
+ elem_id="image-history-gallery"
+ )
+
+ with gr.Column(scale=1):
+ gr.Markdown("### Selected Image")
+ shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
+ shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
+ shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL
with gr.TabItem("Model"):
@@ -281,11 +540,59 @@ def create_event_handlers():
show_progress=True
)
- # History
+ # Gallery pagination handlers
shared.gradio['image_refresh_history'].click(
- get_history_images,
- None,
- gradio('image_history_gallery'),
+ refresh_gallery,
+ gradio('image_current_page'),
+ gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+ show_progress=False
+ )
+
+ shared.gradio['image_next_page'].click(
+ next_page,
+ gradio('image_current_page'),
+ gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+ show_progress=False
+ )
+
+ shared.gradio['image_prev_page'].click(
+ prev_page,
+ gradio('image_current_page'),
+ gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+ show_progress=False
+ )
+
+ shared.gradio['image_go_to_page'].click(
+ go_to_page,
+ gradio('image_page_input', 'image_current_page'),
+ gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+ show_progress=False
+ )
+
+ # Image selection from gallery
+ shared.gradio['image_history_gallery'].select(
+ on_gallery_select,
+ gradio('image_current_page'),
+ gradio('image_selected_path', 'image_settings_display'),
+ show_progress=False
+ )
+
+ # Send to Generate
+ shared.gradio['image_send_to_generate'].click(
+ send_to_generate,
+ gradio('image_selected_path'),
+ gradio(
+ 'image_prompt',
+ 'image_neg_prompt',
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+ 'image_gallery_status'
+ ),
show_progress=False
)
@@ -334,7 +641,7 @@ def generate(state):
all_images.extend(batch_results)
t1 = time.time()
- save_generated_images(all_images, state['image_prompt'], seed)
+ save_generated_images(all_images, state, seed)
logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
return all_images
@@ -402,30 +709,3 @@ def download_image_model_wrapper(model_path):
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
-
-
-def save_generated_images(images, prompt, seed):
- date_str = datetime.now().strftime("%Y-%m-%d")
- folder_path = os.path.join("user_data", "image_outputs", date_str)
- os.makedirs(folder_path, exist_ok=True)
-
- for idx, img in enumerate(images):
- timestamp = datetime.now().strftime("%H-%M-%S")
- filename = f"{timestamp}_{seed}_{idx}.png"
- img.save(os.path.join(folder_path, filename))
-
-
-def get_history_images():
- output_dir = os.path.join("user_data", "image_outputs")
- if not os.path.exists(output_dir):
- return []
-
- image_files = []
- for root, _, files in os.walk(output_dir):
- for file in files:
- if file.endswith((".png", ".jpg", ".jpeg")):
- full_path = os.path.join(root, file)
- image_files.append((full_path, os.path.getmtime(full_path)))
-
- image_files.sort(key=lambda x: x[1], reverse=True)
- return [x[0] for x in image_files]
From 748e2e55fd7b771e45eab8b8cd4b0e5a02de81ba Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 15:44:31 -0800
Subject: [PATCH 28/38] Add steps/second info to log message
---
modules/ui_image_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index b202f6cc..161c8c60 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -643,7 +643,7 @@ def generate(state):
t1 = time.time()
save_generated_images(all_images, state, seed)
- logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
+ logger.info(f'Images generated in {(t1-t0):.2f} seconds ({state["image_steps"]/(t1-t0):.2f} steps/s, seed {seed})')
return all_images
From a7808f7f422def956371e1f194a67205d8d29c26 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 16:02:35 -0800
Subject: [PATCH 29/38] Make filenames always have the same size
---
modules/ui_image_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 161c8c60..ff1b9f67 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -140,7 +140,7 @@ def save_generated_images(images, state, actual_seed):
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
- filename = f"{timestamp}_{actual_seed}_{idx}.png"
+ filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
From 7dfb6e9c57c6ac45c2c77409e032c59739b7724b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 17:05:42 -0800
Subject: [PATCH 30/38] Add quantization options (bnb and quanto)
---
modules/image_models.py | 104 ++++++++++++++++++++++++++-------
modules/shared.py | 13 ++++-
modules/ui.py | 2 +
modules/ui_image_generation.py | 19 ++++--
4 files changed, 109 insertions(+), 29 deletions(-)
diff --git a/modules/image_models.py b/modules/image_models.py
index 9e2075fd..de3743bf 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -8,7 +8,75 @@ from modules.torch_utils import get_device
from modules.utils import resolve_model_path
-def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
+def get_quantization_config(quant_method):
+ """
+ Get the appropriate quantization config based on the selected method.
+
+ Args:
+ quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+
+ Returns:
+ PipelineQuantizationConfig or None
+ """
+ from diffusers.quantizers import PipelineQuantizationConfig
+ from diffusers import BitsAndBytesConfig, QuantoConfig
+
+ if quant_method == 'none' or not quant_method:
+ return None
+
+ # Bitsandbytes 8-bit quantization
+ elif quant_method == 'bnb-8bit':
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": BitsAndBytesConfig(
+ load_in_8bit=True
+ )
+ }
+ )
+
+ # Bitsandbytes 4-bit quantization
+ elif quant_method == 'bnb-4bit':
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ bnb_4bit_use_double_quant=True
+ )
+ }
+ )
+
+ # Quanto 8-bit quantization
+ elif quant_method == 'quanto-8bit':
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int8")
+ }
+ )
+
+ # Quanto 4-bit quantization
+ elif quant_method == 'quanto-4bit':
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int4")
+ }
+ )
+
+ # Quanto 2-bit quantization
+ elif quant_method == 'quanto-2bit':
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int2")
+ }
+ )
+
+ else:
+ logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
+ return None
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
"""
Load a diffusers image generation model.
@@ -18,10 +86,11 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
cpu_offload: Enable CPU offloading for low VRAM
compile_model: Compile the model for faster inference (slow first run)
+ quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
"""
- from diffusers import PipelineQuantizationConfig, ZImagePipeline
+ from diffusers import ZImagePipeline
- logger.info(f"Loading image model \"{model_name}\"")
+ logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
t0 = time.time()
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
@@ -30,28 +99,21 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
model_path = resolve_model_path(model_name, image_model=True)
try:
- # Define quantization config for 8-bit
- pipeline_quant_config = PipelineQuantizationConfig(
- quant_backend="bitsandbytes_8bit",
- quant_kwargs={"load_in_8bit": True},
- )
+ # Get quantization config based on selected method
+ pipeline_quant_config = get_quantization_config(quant_method)
- # Define quantization config for 4-bit
- # pipeline_quant_config = PipelineQuantizationConfig(
- # quant_backend="bitsandbytes_4bit",
- # quant_kwargs={
- # "load_in_4bit": True,
- # "bnb_4bit_quant_type": "nf4", # Or "fp4" for floating point
- # "bnb_4bit_compute_dtype": torch.bfloat16, # For faster computation
- # "bnb_4bit_use_double_quant": True, # Nested quantization for extra savings
- # },
- # )
+ # Load the pipeline
+ load_kwargs = {
+ "torch_dtype": target_dtype,
+ "low_cpu_mem_usage": True,
+ }
+
+ if pipeline_quant_config is not None:
+ load_kwargs["quantization_config"] = pipeline_quant_config
pipe = ZImagePipeline.from_pretrained(
str(model_path),
- quantization_config=pipeline_quant_config,
- torch_dtype=target_dtype,
- low_cpu_mem_usage=True,
+ **load_kwargs
)
if not cpu_offload:
diff --git a/modules/shared.py b/modules/shared.py
index 9a062e91..d33aa717 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16',
group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+group.add_argument('--image-quant', type=str, default=None,
+ choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+ help='Quantization method for image model.')
# Model loader
group = parser.add_argument_group('Model loader')
@@ -317,8 +320,9 @@ settings = {
'image_model_menu': 'None',
'image_dtype': 'bfloat16',
'image_attn_backend': 'sdpa',
- 'image_compile': False,
'image_cpu_offload': False,
+ 'image_compile': False,
+ 'image_quant': 'none',
}
default_settings = copy.deepcopy(settings)
@@ -344,8 +348,8 @@ def do_cmd_flags_warnings():
def apply_image_model_cli_overrides():
- """Apply CLI flags for image model settings, overriding saved settings."""
- if args.image_model:
+ """Apply command-line overrides for image model settings."""
+ if args.image_model is not None:
settings['image_model_menu'] = args.image_model
if args.image_dtype is not None:
settings['image_dtype'] = args.image_dtype
@@ -355,6 +359,9 @@ def apply_image_model_cli_overrides():
settings['image_cpu_offload'] = True
if args.image_compile:
settings['image_compile'] = True
+ if args.image_quant is not None:
+ settings['image_quant'] = args.image_quant
+
def fix_loader_name(name):
diff --git a/modules/ui.py b/modules/ui.py
index 3aba20b4..3bcba56b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -296,6 +296,7 @@ def list_interface_input_elements():
'image_attn_backend',
'image_compile',
'image_cpu_offload',
+ 'image_quant',
]
return elements
@@ -542,6 +543,7 @@ def setup_auto_save():
'image_attn_backend',
'image_compile',
'image_cpu_offload',
+ 'image_quant',
]
for element_name in change_elements:
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index ff1b9f67..a5cf3695 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -432,6 +432,13 @@ def create_ui():
gr.Markdown("## Settings")
with gr.Row():
with gr.Column():
+ shared.gradio['image_quant'] = gr.Dropdown(
+ label='Quantization',
+ choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+ value=shared.settings['image_quant'],
+ info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
+ )
+
shared.gradio['image_dtype'] = gr.Dropdown(
choices=['bfloat16', 'float16'],
value=shared.settings['image_dtype'],
@@ -521,7 +528,7 @@ def create_event_handlers():
shared.gradio['image_load_model'].click(
load_image_model_wrapper,
- gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+ gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
gradio('image_model_status'),
show_progress=True
)
@@ -610,7 +617,8 @@ def generate(state):
dtype=state['image_dtype'],
attn_backend=state['image_attn_backend'],
cpu_offload=state['image_cpu_offload'],
- compile_model=state['image_compile']
+ compile_model=state['image_compile'],
+ quant_method=state['image_quant']
)
if result is None:
logger.error(f"Failed to load model `{model_name}`.")
@@ -647,7 +655,7 @@ def generate(state):
return all_images
-def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
+def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
if not model_name or model_name == 'None':
yield "No model selected"
return
@@ -661,12 +669,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
dtype=dtype,
attn_backend=attn_backend,
cpu_offload=cpu_offload,
- compile_model=compile_model
+ compile_model=compile_model,
+ quant_method=quant_method
)
if result is not None:
shared.image_model_name = model_name
- yield f"✓ Loaded **{model_name}**"
+ yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
else:
yield f"✗ Failed to load `{model_name}`"
except Exception:
From 5fb1380ac1e057085151c44108d238559d6f8b3e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 17:09:32 -0800
Subject: [PATCH 31/38] Handle URLs like https://huggingface.co/Qwen/Qwen-Image
---
modules/ui_image_generation.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index a5cf3695..c6ec76a3 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -697,6 +697,12 @@ def download_image_model_wrapper(model_path):
return
try:
+ model_path = model_path.strip()
+ if model_path.startswith('https://huggingface.co/'):
+ model_path = model_path[len('https://huggingface.co/'):]
+ elif model_path.startswith('huggingface.co/'):
+ model_path = model_path[len('huggingface.co/'):]
+
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
From 225b8c326bf263bc6e5272584998ad1419d589af Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 17:13:16 -0800
Subject: [PATCH 32/38] Try to not break portable builds
---
modules/image_models.py | 6 +++---
modules/ui_image_generation.py | 3 ++-
server.py | 3 ++-
3 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/modules/image_models.py b/modules/image_models.py
index de3743bf..fe149253 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -1,7 +1,5 @@
import time
-import torch
-
import modules.shared as shared
from modules.logging_colors import logger
from modules.torch_utils import get_device
@@ -18,8 +16,9 @@ def get_quantization_config(quant_method):
Returns:
PipelineQuantizationConfig or None
"""
- from diffusers.quantizers import PipelineQuantizationConfig
+ import torch
from diffusers import BitsAndBytesConfig, QuantoConfig
+ from diffusers.quantizers import PipelineQuantizationConfig
if quant_method == 'none' or not quant_method:
return None
@@ -88,6 +87,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
compile_model: Compile the model for faster inference (slow first run)
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
"""
+ import torch
from diffusers import ZImagePipeline
logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index c6ec76a3..09faf423 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -7,7 +7,6 @@ from pathlib import Path
import gradio as gr
import numpy as np
-import torch
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -605,6 +604,8 @@ def create_event_handlers():
def generate(state):
+ import torch
+
model_name = state['image_model_menu']
if not model_name or model_name == 'None':
diff --git a/server.py b/server.py
index 5a75e887..58b3d043 100644
--- a/server.py
+++ b/server.py
@@ -172,7 +172,8 @@ def create_interface():
ui_chat.create_event_handlers()
ui_default.create_event_handlers()
ui_notebook.create_event_handlers()
- ui_image_generation.create_event_handlers()
+ if not shared.args.portable:
+ ui_image_generation.create_event_handlers()
# Other events
ui_file_saving.create_event_handlers()
From f46f49e26c9bd945f73c2c9b1346803d67bc96b2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:18:15 -0800
Subject: [PATCH 33/38] Initial Qwen-Image support
---
modules/image_models.py | 44 ++++++++++++++++++++++++++--------
modules/ui_image_generation.py | 40 +++++++++++++++++++++++--------
2 files changed, 64 insertions(+), 20 deletions(-)
diff --git a/modules/image_models.py b/modules/image_models.py
index fe149253..e4831758 100644
--- a/modules/image_models.py
+++ b/modules/image_models.py
@@ -75,6 +75,22 @@ def get_quantization_config(quant_method):
return None
+def get_pipeline_type(pipe):
+ """
+ Detect the pipeline type based on the loaded pipeline class.
+
+ Returns:
+ str: 'zimage', 'qwenimage', or 'unknown'
+ """
+ class_name = pipe.__class__.__name__
+ if 'ZImage' in class_name:
+ return 'zimage'
+ elif 'QwenImage' in class_name:
+ return 'qwenimage'
+ else:
+ return 'unknown'
+
+
def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
"""
Load a diffusers image generation model.
@@ -88,7 +104,7 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
"""
import torch
- from diffusers import ZImagePipeline
+ from diffusers import DiffusionPipeline
logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
t0 = time.time()
@@ -111,30 +127,37 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
if pipeline_quant_config is not None:
load_kwargs["quantization_config"] = pipeline_quant_config
- pipe = ZImagePipeline.from_pretrained(
+ # Use DiffusionPipeline for automatic pipeline detection
+ # This handles both ZImagePipeline and QwenImagePipeline
+ pipe = DiffusionPipeline.from_pretrained(
str(model_path),
**load_kwargs
)
+ pipeline_type = get_pipeline_type(pipe)
+
if not cpu_offload:
pipe.to(get_device())
- # Set attention backend
- if attn_backend == 'flash_attention_2':
- pipe.transformer.set_attention_backend("flash")
- elif attn_backend == 'flash_attention_3':
- pipe.transformer.set_attention_backend("_flash_3")
- # sdpa is the default, no action needed
+ # Set attention backend (if supported by the pipeline)
+ if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+ if attn_backend == 'flash_attention_2':
+ pipe.transformer.set_attention_backend("flash")
+ elif attn_backend == 'flash_attention_3':
+ pipe.transformer.set_attention_backend("_flash_3")
+ # sdpa is the default, no action needed
if compile_model:
- logger.info("Compiling model (first run will be slow)...")
- pipe.transformer.compile()
+ if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+ logger.info("Compiling model (first run will be slow)...")
+ pipe.transformer.compile()
if cpu_offload:
pipe.enable_model_cpu_offload()
shared.image_model = pipe
shared.image_model_name = model_name
+ shared.image_pipeline_type = pipeline_type
logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
return pipe
@@ -152,6 +175,7 @@ def unload_image_model():
del shared.image_model
shared.image_model = None
shared.image_model_name = 'None'
+ shared.image_pipeline_type = None
from modules.torch_utils import clear_torch_cache
clear_torch_cache()
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 09faf423..42c8c21f 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -604,7 +604,12 @@ def create_event_handlers():
def generate(state):
+ """
+ Generate images using the loaded model.
+ Automatically adjusts parameters based on pipeline type.
+ """
import torch
+ import numpy as np
model_name = state['image_model_menu']
@@ -634,19 +639,34 @@ def generate(state):
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
+ # Get pipeline type for parameter adjustment
+ pipeline_type = getattr(shared, 'image_pipeline_type', None)
+ if pipeline_type is None:
+ pipeline_type = get_pipeline_type(shared.image_model)
+
+ # Build generation kwargs based on pipeline type
+ gen_kwargs = {
+ "prompt": state['image_prompt'],
+ "negative_prompt": state['image_neg_prompt'],
+ "height": int(state['image_height']),
+ "width": int(state['image_width']),
+ "num_inference_steps": int(state['image_steps']),
+ "num_images_per_prompt": int(state['image_batch_size']),
+ "generator": generator,
+ }
+
+ # Add pipeline-specific parameters
+ if pipeline_type == 'qwenimage':
+ # Qwen-Image uses true_cfg_scale instead of guidance_scale
+ gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
+ else:
+ # Z-Image and others use guidance_scale
+ gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
+
t0 = time.time()
for i in range(int(state['image_batch_count'])):
generator.manual_seed(int(seed + i))
- batch_results = shared.image_model(
- prompt=state['image_prompt'],
- negative_prompt=state['image_neg_prompt'],
- height=int(state['image_height']),
- width=int(state['image_width']),
- num_inference_steps=int(state['image_steps']),
- guidance_scale=0.0,
- num_images_per_prompt=int(state['image_batch_size']),
- generator=generator,
- ).images
+ batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
t1 = time.time()
From 322aab34105f4fe54361939f53af7c2e63b7df4b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:20:47 -0800
Subject: [PATCH 34/38] Increase the image_steps maximum
---
modules/ui_image_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 42c8c21f..749b8b95 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -368,7 +368,7 @@ def create_ui():
gr.Markdown("### Config")
with gr.Row():
with gr.Column():
- shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
+ shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
From 151b552bc38aa373805280b30dcac3e8b89fbd8f Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:24:02 -0800
Subject: [PATCH 35/38] Decrease the resolution step to allow for 1368
---
modules/ui_image_generation.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 749b8b95..defe1788 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -23,7 +23,7 @@ ASPECT_RATIOS = {
"Custom": None,
}
-STEP = 32
+STEP = 16
IMAGES_PER_PAGE = 64
# Settings keys to save in PNG metadata (Generate tab only)
From d75d7a3a63a2ab4bb455211331da1caeb61c6232 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:41:15 -0800
Subject: [PATCH 36/38] Add a CFG scale slider, add qwen3 magic
---
modules/shared.py | 1 +
modules/ui.py | 2 ++
modules/ui_image_generation.py | 46 +++++++++++++++++++++++++---------
3 files changed, 37 insertions(+), 12 deletions(-)
diff --git a/modules/shared.py b/modules/shared.py
index d33aa717..fef67489 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -314,6 +314,7 @@ settings = {
'image_height': 1024,
'image_aspect_ratio': '1:1 Square',
'image_steps': 9,
+ 'image_cfg_scale': 0.0,
'image_seed': -1,
'image_batch_size': 1,
'image_batch_count': 1,
diff --git a/modules/ui.py b/modules/ui.py
index 3bcba56b..ca45a444 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -288,6 +288,7 @@ def list_interface_input_elements():
'image_height',
'image_aspect_ratio',
'image_steps',
+ 'image_cfg_scale',
'image_seed',
'image_batch_size',
'image_batch_count',
@@ -535,6 +536,7 @@ def setup_auto_save():
'image_height',
'image_aspect_ratio',
'image_steps',
+ 'image_cfg_scale',
'image_seed',
'image_batch_size',
'image_batch_count',
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index defe1788..2aa9bcb4 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -37,6 +37,7 @@ METADATA_SETTINGS_KEYS = [
'image_seed',
'image_batch_size',
'image_batch_count',
+ 'image_cfg_scale',
]
# Cache for all image paths
@@ -176,6 +177,7 @@ def format_metadata_for_display(metadata):
('image_height', 'Height'),
('image_aspect_ratio', 'Aspect Ratio'),
('image_steps', 'Steps'),
+ ('image_cfg_scale', 'CFG Scale'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
@@ -298,11 +300,11 @@ def on_gallery_select(evt: gr.SelectData, current_page):
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
- return [gr.update()] * 9 + ["No image selected"]
+ return [gr.update()] * 10 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
- return [gr.update()] * 9 + ["No settings found in this image"]
+ return [gr.update()] * 10 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
@@ -315,13 +317,13 @@ def send_to_generate(selected_image_path):
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
+ gr.update(value=metadata.get('image_cfg_scale', 0.0)),
]
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
return updates + [status]
-
def create_ui():
if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu']
@@ -352,9 +354,9 @@ def create_ui():
gr.Markdown("### Dimensions")
with gr.Row():
with gr.Column():
- shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
+ shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")
with gr.Column():
- shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
+ shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")
shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
with gr.Row():
@@ -369,6 +371,13 @@ def create_ui():
with gr.Row():
with gr.Column():
shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
+ shared.gradio['image_cfg_scale'] = gr.Slider(
+ 0.0, 10.0,
+ value=0.0,
+ step=0.1,
+ label="CFG Scale",
+ info="Z-Image Turbo: 0.0 | Qwen: 4.0"
+ )
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
@@ -597,6 +606,7 @@ def create_event_handlers():
'image_seed',
'image_batch_size',
'image_batch_count',
+ 'image_cfg_scale',
'image_gallery_status'
),
show_progress=False
@@ -644,9 +654,19 @@ def generate(state):
if pipeline_type is None:
pipeline_type = get_pipeline_type(shared.image_model)
- # Build generation kwargs based on pipeline type
+ # Process Prompt
+ prompt = state['image_prompt']
+
+ # Apply "Positive Magic" for Qwen models only
+ if pipeline_type == 'qwenimage':
+ magic_suffix = ", Ultra HD, 4K, cinematic composition"
+ # Avoid duplication if user already added it
+ if magic_suffix.strip(", ") not in prompt:
+ prompt += magic_suffix
+
+ # Build generation kwargs
gen_kwargs = {
- "prompt": state['image_prompt'],
+ "prompt": prompt,
"negative_prompt": state['image_neg_prompt'],
"height": int(state['image_height']),
"width": int(state['image_width']),
@@ -655,13 +675,15 @@ def generate(state):
"generator": generator,
}
- # Add pipeline-specific parameters
+ # Add pipeline-specific parameters for CFG
+ cfg_val = state.get('image_cfg_scale', 0.0)
+
if pipeline_type == 'qwenimage':
- # Qwen-Image uses true_cfg_scale instead of guidance_scale
- gen_kwargs["true_cfg_scale"] = state.get('image_cfg_scale', 4.0)
+ # Qwen-Image uses true_cfg_scale (typically 4.0)
+ gen_kwargs["true_cfg_scale"] = cfg_val
else:
- # Z-Image and others use guidance_scale
- gen_kwargs["guidance_scale"] = state.get('image_cfg_scale', 0.0)
+ # Z-Image and others use guidance_scale (typically 0.0 for Turbo)
+ gen_kwargs["guidance_scale"] = cfg_val
t0 = time.time()
for i in range(int(state['image_batch_count'])):
From 5c61fcf47916b2ca47f58d758107c9df949e3056 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:59:11 -0800
Subject: [PATCH 37/38] Autosave on prompt change
---
modules/ui.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/modules/ui.py b/modules/ui.py
index ca45a444..9700d297 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -532,6 +532,8 @@ def setup_auto_save():
'include_past_attachments',
# Image generation tab (ui_image_generation.py)
+ 'image_prompt',
+ 'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
From f45412676da068ecff3d2b472f8d7a16806eeaca Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 19:05:47 -0800
Subject: [PATCH 38/38] Minor label changes
---
modules/ui_image_generation.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 2aa9bcb4..888cc532 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -348,7 +348,7 @@ def create_ui():
value=shared.settings['image_neg_prompt']
)
- shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
+ shared.gradio['image_generate_btn'] = gr.Button("GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("
")
gr.Markdown("### Dimensions")
@@ -419,7 +419,7 @@ def create_ui():
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
- shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
+ shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL