From b42192c2b7e6fb541b8d6e77c478653ba4a8c1e4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 1 Dec 2025 10:42:03 -0800
Subject: [PATCH] Implement settings autosaving
---
modules/image_model_settings.py | 105 ---------
modules/shared.py | 30 +++
modules/ui.py | 34 ++-
modules/ui_image_generation.py | 382 +++++++++++++-------------------
modules/ui_model_menu.py | 63 ------
server.py | 4 +
6 files changed, 217 insertions(+), 401 deletions(-)
delete mode 100644 modules/image_model_settings.py
diff --git a/modules/image_model_settings.py b/modules/image_model_settings.py
deleted file mode 100644
index edb6bf20..00000000
--- a/modules/image_model_settings.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from pathlib import Path
-
-import yaml
-
-import modules.shared as shared
-from modules.logging_colors import logger
-
-DEFAULTS = {
- 'model_name': 'None',
- 'dtype': 'bfloat16',
- 'attn_backend': 'sdpa',
- 'cpu_offload': False,
- 'compile_model': False,
-}
-
-
-def get_settings_path():
- """Get the path to the image model settings file."""
- return Path(shared.args.image_model_dir) / 'settings.yaml'
-
-
-def load_yaml_settings():
- """Load raw settings from yaml file."""
- settings_path = get_settings_path()
-
- if not settings_path.exists():
- return {}
-
- try:
- with open(settings_path, 'r') as f:
- saved = yaml.safe_load(f)
- return saved if saved else {}
- except Exception as e:
- logger.warning(f"Failed to load image model settings: {e}")
- return {}
-
-
-def get_effective_settings():
- """
- Get effective settings with precedence:
- 1. CLI flag (if provided)
- 2. Saved yaml value (if exists)
- 3. Hardcoded default
-
- Returns a dict with all settings.
- """
- yaml_settings = load_yaml_settings()
-
- effective = {}
-
- # model_name: CLI --image-model > yaml > default
- if shared.args.image_model:
- effective['model_name'] = shared.args.image_model
- else:
- effective['model_name'] = yaml_settings.get('model_name', DEFAULTS['model_name'])
-
- # dtype: CLI --image-dtype > yaml > default
- if shared.args.image_dtype is not None:
- effective['dtype'] = shared.args.image_dtype
- else:
- effective['dtype'] = yaml_settings.get('dtype', DEFAULTS['dtype'])
-
- # attn_backend: CLI --image-attn-backend > yaml > default
- if shared.args.image_attn_backend is not None:
- effective['attn_backend'] = shared.args.image_attn_backend
- else:
- effective['attn_backend'] = yaml_settings.get('attn_backend', DEFAULTS['attn_backend'])
-
- # cpu_offload: CLI --image-cpu-offload > yaml > default
- # For store_true flags, check if explicitly set (True means it was passed)
- if shared.args.image_cpu_offload:
- effective['cpu_offload'] = True
- else:
- effective['cpu_offload'] = yaml_settings.get('cpu_offload', DEFAULTS['cpu_offload'])
-
- # compile_model: CLI --image-compile > yaml > default
- if shared.args.image_compile:
- effective['compile_model'] = True
- else:
- effective['compile_model'] = yaml_settings.get('compile_model', DEFAULTS['compile_model'])
-
- return effective
-
-
-def save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model):
- """Save image model settings to yaml."""
- settings_path = get_settings_path()
-
- # Ensure directory exists
- settings_path.parent.mkdir(parents=True, exist_ok=True)
-
- settings = {
- 'model_name': model_name,
- 'dtype': dtype,
- 'attn_backend': attn_backend,
- 'cpu_offload': cpu_offload,
- 'compile_model': compile_model,
- }
-
- try:
- with open(settings_path, 'w') as f:
- yaml.dump(settings, f, default_flow_style=False)
- logger.info(f"Saved image model settings to {settings_path}")
- except Exception as e:
- logger.error(f"Failed to save image model settings: {e}")
diff --git a/modules/shared.py b/modules/shared.py
index e54eca8f..9a062e91 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -303,6 +303,22 @@ settings = {
# Extensions
'default_extensions': [],
+
+ # Image generation settings
+ 'image_prompt': '',
+ 'image_neg_prompt': '',
+ 'image_width': 1024,
+ 'image_height': 1024,
+ 'image_aspect_ratio': '1:1 Square',
+ 'image_steps': 9,
+ 'image_seed': -1,
+ 'image_batch_size': 1,
+ 'image_batch_count': 1,
+ 'image_model_menu': 'None',
+ 'image_dtype': 'bfloat16',
+ 'image_attn_backend': 'sdpa',
+ 'image_compile': False,
+ 'image_cpu_offload': False,
}
default_settings = copy.deepcopy(settings)
@@ -327,6 +343,20 @@ def do_cmd_flags_warnings():
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
+def apply_image_model_cli_overrides():
+ """Apply CLI flags for image model settings, overriding saved settings."""
+ if args.image_model:
+ settings['image_model_menu'] = args.image_model
+ if args.image_dtype is not None:
+ settings['image_dtype'] = args.image_dtype
+ if args.image_attn_backend is not None:
+ settings['image_attn_backend'] = args.image_attn_backend
+ if args.image_cpu_offload:
+ settings['image_cpu_offload'] = True
+ if args.image_compile:
+ settings['image_compile'] = True
+
+
def fix_loader_name(name):
if not name:
return name
diff --git a/modules/ui.py b/modules/ui.py
index f99e8b6a..3aba20b4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -280,6 +280,24 @@ def list_interface_input_elements():
'include_past_attachments',
]
+ # Image generation elements
+ elements += [
+ 'image_prompt',
+ 'image_neg_prompt',
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+ 'image_model_menu',
+ 'image_dtype',
+ 'image_attn_backend',
+ 'image_compile',
+ 'image_cpu_offload',
+ ]
+
return elements
@@ -509,7 +527,21 @@ def setup_auto_save():
'theme_state',
'show_two_notebook_columns',
'paste_to_attachment',
- 'include_past_attachments'
+ 'include_past_attachments',
+
+ # Image generation tab (ui_image_generation.py)
+ 'image_width',
+ 'image_height',
+ 'image_aspect_ratio',
+ 'image_steps',
+ 'image_seed',
+ 'image_batch_size',
+ 'image_batch_count',
+ 'image_model_menu',
+ 'image_dtype',
+ 'image_attn_backend',
+ 'image_compile',
+ 'image_cpu_offload',
]
for element_name in change_elements:
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 038e96c8..fe0c8120 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -7,14 +7,10 @@ import gradio as gr
import numpy as np
import torch
-from modules import shared, utils
-from modules.image_model_settings import (
- get_effective_settings,
- save_image_model_settings
-)
+from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
+from modules.utils import gradio
-# Aspect ratio definitions: name -> (width_ratio, height_ratio)
ASPECT_RATIOS = {
"1:1 Square": (1, 1),
"16:9 Cinema": (16, 9),
@@ -23,50 +19,34 @@ ASPECT_RATIOS = {
"Custom": None,
}
-STEP = 32 # Slider step for rounding
+STEP = 32
def round_to_step(value, step=STEP):
- """Round a value to the nearest step."""
return round(value / step) * step
def clamp(value, min_val, max_val):
- """Clamp value between min and max."""
return max(min_val, min(max_val, value))
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
- """
- Apply an aspect ratio preset.
-
- Logic to prevent dimension creep:
- - For tall ratios (like 9:16): keep width fixed, calculate height
- - For wide ratios (like 16:9): keep height fixed, calculate width
- - For square (1:1): use the smaller of the current dimensions
-
- Returns (new_width, new_height).
- """
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return current_width, current_height
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
if w_ratio == h_ratio:
- # Square ratio - use the smaller current dimension to prevent creep
base = min(current_width, current_height)
new_width = base
new_height = base
elif w_ratio < h_ratio:
- # Tall ratio (like 9:16) - width is the smaller side, keep it fixed
new_width = current_width
new_height = round_to_step(current_width * h_ratio / w_ratio)
else:
- # Wide ratio (like 16:9) - height is the smaller side, keep it fixed
new_height = current_height
new_width = round_to_step(current_height * w_ratio / h_ratio)
- # Clamp to slider bounds
new_width = clamp(new_width, 256, 2048)
new_height = clamp(new_height, 256, 2048)
@@ -74,7 +54,6 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height):
def update_height_from_width(width, aspect_ratio):
- """Update height when width changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
@@ -86,7 +65,6 @@ def update_height_from_width(width, aspect_ratio):
def update_width_from_height(height, aspect_ratio):
- """Update width when height changes (if not Custom)."""
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
return gr.update()
@@ -98,16 +76,13 @@ def update_width_from_height(height, aspect_ratio):
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
- """Swap dimensions and update aspect ratio to match (or set to Custom)."""
new_width, new_height = height, width
- # Try to find a matching aspect ratio for the swapped dimensions
new_ratio = "Custom"
for name, ratios in ASPECT_RATIOS.items():
if ratios is None:
continue
w_r, h_r = ratios
- # Check if the swapped dimensions match this ratio (within tolerance)
expected_height = new_width * h_r / w_r
if abs(expected_height - new_height) < STEP:
new_ratio = name
@@ -117,291 +92,257 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
def create_ui():
- # Get effective settings (CLI > yaml > defaults)
- settings = get_effective_settings()
-
- # Update shared state (but don't load the model yet)
- if settings['model_name'] != 'None':
- shared.image_model_name = settings['model_name']
+ if shared.settings['image_model_menu'] != 'None':
+ shared.image_model_name = shared.settings['image_model_menu']
with gr.Tab("Image AI", elem_id="image-ai-tab"):
with gr.Tabs():
- # TAB 1: GENERATION STUDIO
+ # TAB 1: GENERATE
with gr.TabItem("Generate"):
with gr.Row():
-
- # === LEFT COLUMN: CONTROLS ===
with gr.Column(scale=4, min_width=350):
+ shared.gradio['image_prompt'] = gr.Textbox(
+ label="Prompt",
+ placeholder="Describe your imagination...",
+ lines=3,
+ autofocus=True,
+ value=shared.settings['image_prompt']
+ )
+ shared.gradio['image_neg_prompt'] = gr.Textbox(
+ label="Negative Prompt",
+ placeholder="Low quality...",
+ lines=3,
+ value=shared.settings['image_neg_prompt']
+ )
- # 1. PROMPT
- prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True)
- neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3)
-
- # 2. GENERATE BUTTON
- generate_btn = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
+ shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
gr.HTML("
")
- # 3. DIMENSIONS
gr.Markdown("### 📐 Dimensions")
with gr.Row():
with gr.Column():
- width_slider = gr.Slider(256, 2048, value=1024, step=32, label="Width")
-
+ shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
with gr.Column():
- height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height")
-
- swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
+ shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
+ shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
with gr.Row():
- preset_radio = gr.Radio(
+ shared.gradio['image_aspect_ratio'] = gr.Radio(
choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
- value="1:1 Square",
+ value=shared.settings['image_aspect_ratio'],
label="Aspect Ratio",
interactive=True
)
- # 4. SETTINGS & BATCHING
gr.Markdown("### ⚙️ Config")
with gr.Row():
with gr.Column():
- steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps")
- seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 = Random")
-
+ shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
+ shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
- batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
- batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
+ shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
+ shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
- # === RIGHT COLUMN: VIEWPORT ===
with gr.Column(scale=6, min_width=500):
with gr.Column(elem_classes=["viewport-container"]):
- output_gallery = gr.Gallery(
- label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True
- )
+ shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True)
with gr.Row():
- used_seed = gr.Markdown(label="Info", interactive=False)
+ shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)
- # TAB 2: HISTORY VIEWER
+ # TAB 2: GALLERY
with gr.TabItem("Gallery"):
with gr.Row():
- refresh_btn = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
+ shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
+ shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True)
- history_gallery = gr.Gallery(
- label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True
- )
-
- # TAB 3: MODEL SETTINGS
+ # TAB 3: MODEL
with gr.TabItem("Model"):
with gr.Row():
with gr.Column():
with gr.Row():
- image_model_menu = gr.Dropdown(
+ shared.gradio['image_model_menu'] = gr.Dropdown(
choices=utils.get_available_image_models(),
- value=settings['model_name'],
+ value=shared.settings['image_model_menu'],
label='Model',
elem_classes='slim-dropdown'
)
- image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
- image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button')
- image_unload_model = gr.Button("Unload", elem_classes='refresh-button')
+ shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
+ shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
+ shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
gr.Markdown("## Settings")
-
with gr.Row():
with gr.Column():
- image_dtype = gr.Dropdown(
+ shared.gradio['image_dtype'] = gr.Dropdown(
choices=['bfloat16', 'float16'],
- value=settings['dtype'],
+ value=shared.settings['image_dtype'],
label='Data Type',
info='bfloat16 recommended for modern GPUs'
)
-
- image_attn_backend = gr.Dropdown(
+ shared.gradio['image_attn_backend'] = gr.Dropdown(
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
- value=settings['attn_backend'],
+ value=shared.settings['image_attn_backend'],
label='Attention Backend',
info='SDPA is default. Flash Attention requires compatible GPU.'
)
-
with gr.Column():
- image_compile = gr.Checkbox(
- value=settings['compile_model'],
+ shared.gradio['image_compile'] = gr.Checkbox(
+ value=shared.settings['image_compile'],
label='Compile Model',
info='Faster inference after first run. First run will be slow.'
)
-
- image_cpu_offload = gr.Checkbox(
- value=settings['cpu_offload'],
+ shared.gradio['image_cpu_offload'] = gr.Checkbox(
+ value=shared.settings['image_cpu_offload'],
label='CPU Offload',
info='Enable for low VRAM GPUs. Slower but uses less memory.'
)
with gr.Column():
- image_download_path = gr.Textbox(
+ shared.gradio['image_download_path'] = gr.Textbox(
label="Download model",
placeholder="Tongyi-MAI/Z-Image-Turbo",
- info="Enter the HuggingFace model path like Tongyi-MAI/Z-Image-Turbo. Use : for branch, e.g. Tongyi-MAI/Z-Image-Turbo:main"
+ info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
)
- image_download_btn = gr.Button("Download", variant='primary')
- image_model_status = gr.Markdown(
- value=f"Model: **{settings['model_name']}** (not loaded)" if settings['model_name'] != 'None' else "No model selected"
+ shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
+ shared.gradio['image_model_status'] = gr.Markdown(
+ value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
)
- # === WIRING ===
- # Aspect ratio preset changes -> update dimensions
- preset_radio.change(
- fn=apply_aspect_ratio,
- inputs=[preset_radio, width_slider, height_slider],
- outputs=[width_slider, height_slider],
- show_progress=False
- )
+def create_event_handlers():
+ # Dimension controls
+ shared.gradio['image_aspect_ratio'].change(
+ apply_aspect_ratio,
+ gradio('image_aspect_ratio', 'image_width', 'image_height'),
+ gradio('image_width', 'image_height'),
+ show_progress=False
+ )
- # Width slider changes -> update height (if not Custom)
- width_slider.release(
- fn=update_height_from_width,
- inputs=[width_slider, preset_radio],
- outputs=[height_slider],
- show_progress=False
- )
+ shared.gradio['image_width'].release(
+ update_height_from_width,
+ gradio('image_width', 'image_aspect_ratio'),
+ gradio('image_height'),
+ show_progress=False
+ )
- # Height slider changes -> update width (if not Custom)
- height_slider.release(
- fn=update_width_from_height,
- inputs=[height_slider, preset_radio],
- outputs=[width_slider],
- show_progress=False
- )
+ shared.gradio['image_height'].release(
+ update_width_from_height,
+ gradio('image_height', 'image_aspect_ratio'),
+ gradio('image_width'),
+ show_progress=False
+ )
- # Swap button -> swap dimensions and update aspect ratio
- swap_btn.click(
- fn=swap_dimensions_and_update_ratio,
- inputs=[width_slider, height_slider, preset_radio],
- outputs=[width_slider, height_slider, preset_radio],
- show_progress=False
- )
+ shared.gradio['image_swap_btn'].click(
+ swap_dimensions_and_update_ratio,
+ gradio('image_width', 'image_height', 'image_aspect_ratio'),
+ gradio('image_width', 'image_height', 'image_aspect_ratio'),
+ show_progress=False
+ )
- # Generation
- inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
- outputs = [output_gallery, used_seed]
+ # Generation
+ shared.gradio['image_generate_btn'].click(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- generate_btn.click(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
- prompt.submit(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
- neg_prompt.submit(
- fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile),
- inputs=inputs,
- outputs=outputs
- )
+ shared.gradio['image_prompt'].submit(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- # Model tab events
- image_refresh_models.click(
- fn=lambda: gr.update(choices=utils.get_available_image_models()),
- inputs=None,
- outputs=[image_model_menu],
- show_progress=False
- )
+ shared.gradio['image_neg_prompt'].submit(
+ ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
+ generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
- image_load_model.click(
- fn=load_image_model_wrapper,
- inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile],
- outputs=[image_model_status],
- show_progress=True
- )
+ # Model management
+ shared.gradio['image_refresh_models'].click(
+ lambda: gr.update(choices=utils.get_available_image_models()),
+ None,
+ gradio('image_model_menu'),
+ show_progress=False
+ )
- image_unload_model.click(
- fn=unload_image_model_wrapper,
- inputs=None,
- outputs=[image_model_status],
- show_progress=False
- )
+ shared.gradio['image_load_model'].click(
+ load_image_model_wrapper,
+ gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+ gradio('image_model_status'),
+ show_progress=True
+ )
- image_download_btn.click(
- fn=download_image_model_wrapper,
- inputs=[image_download_path],
- outputs=[image_model_status, image_model_menu],
- show_progress=True
- )
+ shared.gradio['image_unload_model'].click(
+ unload_image_model_wrapper,
+ None,
+ gradio('image_model_status'),
+ show_progress=False
+ )
- # History
- refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False)
+ shared.gradio['image_download_btn'].click(
+ download_image_model_wrapper,
+ gradio('image_download_path'),
+ gradio('image_model_status', 'image_model_menu'),
+ show_progress=True
+ )
+
+ # History
+ shared.gradio['image_refresh_history'].click(
+ get_history_images,
+ None,
+ gradio('image_history_gallery'),
+ show_progress=False
+ )
-def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq,
- model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox):
- """Generate images with the current model settings."""
+def generate(state):
+ model_name = state['image_model_menu']
- model_name = shared.image_model_name
-
- if model_name == 'None':
+ if not model_name or model_name == 'None':
return [], "No image model selected. Go to the Model tab and select a model."
- # Auto-load model if not loaded
if shared.image_model is None:
- # Get effective settings (CLI > yaml > defaults)
- settings = get_effective_settings()
-
result = load_image_model(
model_name,
- dtype=settings['dtype'],
- attn_backend=settings['attn_backend'],
- cpu_offload=settings['cpu_offload'],
- compile_model=settings['compile_model']
+ dtype=state['image_dtype'],
+ attn_backend=state['image_attn_backend'],
+ cpu_offload=state['image_cpu_offload'],
+ compile_model=state['image_compile']
)
-
if result is None:
return [], f"Failed to load model `{model_name}`."
+ shared.image_model_name = model_name
+
+ seed = state['image_seed']
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
- # Sequential loop (easier on VRAM)
- for i in range(int(batch_count_seq)):
- current_seed = seed + i
- generator.manual_seed(int(current_seed))
-
- # Parallel generation
+ for i in range(int(state['image_batch_count'])):
+ generator.manual_seed(int(seed + i))
batch_results = shared.image_model(
- prompt=prompt,
- negative_prompt=neg_prompt,
- height=int(height),
- width=int(width),
- num_inference_steps=int(steps),
+ prompt=state['image_prompt'],
+ negative_prompt=state['image_neg_prompt'],
+ height=int(state['image_height']),
+ width=int(state['image_width']),
+ num_inference_steps=int(state['image_steps']),
guidance_scale=0.0,
- num_images_per_prompt=int(batch_size_parallel),
+ num_images_per_prompt=int(state['image_batch_size']),
generator=generator,
).images
-
all_images.extend(batch_results)
- # Save to disk
- save_generated_images(all_images, prompt, seed)
-
+ save_generated_images(all_images, state['image_prompt'], seed)
return all_images, f"Seed: {seed}"
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
- """Load model and save settings."""
- if model_name == 'None' or not model_name:
+ if not model_name or model_name == 'None':
yield "No model selected"
return
try:
yield f"Loading `{model_name}`..."
-
- # Unload existing model first
unload_image_model()
- # Load the new model
result = load_image_model(
model_name,
dtype=dtype,
@@ -411,29 +352,22 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
)
if result is not None:
- # Save settings to yaml
- save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model)
+ shared.image_model_name = model_name
yield f"✓ Loaded **{model_name}**"
else:
yield f"✗ Failed to load `{model_name}`"
-
except Exception:
- exc = traceback.format_exc()
- yield f"Error:\n```\n{exc}\n```"
+ yield f"Error:\n```\n{traceback.format_exc()}\n```"
def unload_image_model_wrapper():
- """Unload model wrapper."""
unload_image_model()
-
if shared.image_model_name != 'None':
return f"Model: **{shared.image_model_name}** (not loaded)"
- else:
- return "No model loaded"
+ return "No model loaded"
def download_image_model_wrapper(model_path):
- """Download a model from Hugging Face."""
from huggingface_hub import snapshot_download
if not model_path:
@@ -441,17 +375,15 @@ def download_image_model_wrapper(model_path):
return
try:
- # Parse model name and branch
if ':' in model_path:
model_id, branch = model_path.rsplit(':', 1)
else:
model_id, branch = model_path, 'main'
- # Output folder name (username_model format)
folder_name = model_id.replace('/', '_')
output_folder = Path(shared.args.image_model_dir) / folder_name
- yield f"Downloading `{model_id}` (branch: {branch}) to `{output_folder}`...", gr.update()
+ yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
snapshot_download(
repo_id=model_id,
@@ -460,48 +392,34 @@ def download_image_model_wrapper(model_path):
local_dir_use_symlinks=False,
)
- # Refresh the model list
new_choices = utils.get_available_image_models()
-
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
-
except Exception:
- exc = traceback.format_exc()
- yield f"Error:\n```\n{exc}\n```", gr.update()
+ yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
- """Save generated images to disk."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
- saved_paths = []
-
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{seed}_{idx}.png"
- full_path = os.path.join(folder_path, filename)
-
- img.save(full_path)
- saved_paths.append(full_path)
-
- return saved_paths
+ img.save(os.path.join(folder_path, filename))
def get_history_images():
- """Scan the outputs folder and return all images, newest first."""
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
image_files = []
- for root, dirs, files in os.walk(output_dir):
+ for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
- mtime = os.path.getmtime(full_path)
- image_files.append((full_path, mtime))
+ image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index a5e0f640..86adc229 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -434,66 +434,3 @@ def format_file_size(size_bytes):
return f"{s:.2f} {size_names[i]}"
else:
return f"{s:.1f} {size_names[i]}"
-
-
-def load_image_model_wrapper(selected_model):
- """Wrapper for loading image models with status updates."""
- from modules.image_models import load_image_model, unload_image_model
-
- if selected_model == 'None' or not selected_model:
- yield "No model selected"
- return
-
- try:
- yield f"Loading `{selected_model}`..."
- unload_image_model()
- result = load_image_model(selected_model)
-
- if result is not None:
- yield f"Successfully loaded `{selected_model}`."
- else:
- yield f"Failed to load `{selected_model}`."
- except Exception:
- exc = traceback.format_exc()
- yield exc.replace('\n', '\n\n')
-
-
-def handle_unload_image_model_click():
- """Handler for the image model unload button."""
- from modules.image_models import unload_image_model
- unload_image_model()
- return "Image model unloaded"
-
-
-def download_image_model_wrapper(custom_model):
- """Download an image model from Hugging Face."""
- from huggingface_hub import snapshot_download
-
- if not custom_model:
- yield "No model specified"
- return
-
- try:
- # Parse model name and branch
- if ':' in custom_model:
- model_name, branch = custom_model.rsplit(':', 1)
- else:
- model_name, branch = custom_model, 'main'
-
- # Output folder
- output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
-
- yield f"Downloading `{model_name}` (branch: {branch})..."
-
- snapshot_download(
- repo_id=model_name,
- revision=branch,
- local_dir=output_folder,
- local_dir_use_symlinks=False,
- )
-
- yield f"Model successfully saved to `{output_folder}/`."
-
- except Exception:
- exc = traceback.format_exc()
- yield exc.replace('\n', '\n\n')
diff --git a/server.py b/server.py
index 87bbdc4a..5a75e887 100644
--- a/server.py
+++ b/server.py
@@ -172,6 +172,7 @@ def create_interface():
ui_chat.create_event_handlers()
ui_default.create_event_handlers()
ui_notebook.create_event_handlers()
+ ui_image_generation.create_event_handlers()
# Other events
ui_file_saving.create_event_handlers()
@@ -258,6 +259,9 @@ if __name__ == "__main__":
if new_settings:
shared.settings.update(new_settings)
+ # Apply CLI overrides for image model settings (CLI flags take precedence over saved settings)
+ shared.apply_image_model_cli_overrides()
+
# Fallback settings for models
shared.model_config['.*'] = get_fallback_settings()
shared.model_config.move_to_end('.*', last=False) # Move to the beginning