diff --git a/modules/image_model_settings.py b/modules/image_model_settings.py deleted file mode 100644 index edb6bf20..00000000 --- a/modules/image_model_settings.py +++ /dev/null @@ -1,105 +0,0 @@ -from pathlib import Path - -import yaml - -import modules.shared as shared -from modules.logging_colors import logger - -DEFAULTS = { - 'model_name': 'None', - 'dtype': 'bfloat16', - 'attn_backend': 'sdpa', - 'cpu_offload': False, - 'compile_model': False, -} - - -def get_settings_path(): - """Get the path to the image model settings file.""" - return Path(shared.args.image_model_dir) / 'settings.yaml' - - -def load_yaml_settings(): - """Load raw settings from yaml file.""" - settings_path = get_settings_path() - - if not settings_path.exists(): - return {} - - try: - with open(settings_path, 'r') as f: - saved = yaml.safe_load(f) - return saved if saved else {} - except Exception as e: - logger.warning(f"Failed to load image model settings: {e}") - return {} - - -def get_effective_settings(): - """ - Get effective settings with precedence: - 1. CLI flag (if provided) - 2. Saved yaml value (if exists) - 3. Hardcoded default - - Returns a dict with all settings. - """ - yaml_settings = load_yaml_settings() - - effective = {} - - # model_name: CLI --image-model > yaml > default - if shared.args.image_model: - effective['model_name'] = shared.args.image_model - else: - effective['model_name'] = yaml_settings.get('model_name', DEFAULTS['model_name']) - - # dtype: CLI --image-dtype > yaml > default - if shared.args.image_dtype is not None: - effective['dtype'] = shared.args.image_dtype - else: - effective['dtype'] = yaml_settings.get('dtype', DEFAULTS['dtype']) - - # attn_backend: CLI --image-attn-backend > yaml > default - if shared.args.image_attn_backend is not None: - effective['attn_backend'] = shared.args.image_attn_backend - else: - effective['attn_backend'] = yaml_settings.get('attn_backend', DEFAULTS['attn_backend']) - - # cpu_offload: CLI --image-cpu-offload > yaml > default - # For store_true flags, check if explicitly set (True means it was passed) - if shared.args.image_cpu_offload: - effective['cpu_offload'] = True - else: - effective['cpu_offload'] = yaml_settings.get('cpu_offload', DEFAULTS['cpu_offload']) - - # compile_model: CLI --image-compile > yaml > default - if shared.args.image_compile: - effective['compile_model'] = True - else: - effective['compile_model'] = yaml_settings.get('compile_model', DEFAULTS['compile_model']) - - return effective - - -def save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model): - """Save image model settings to yaml.""" - settings_path = get_settings_path() - - # Ensure directory exists - settings_path.parent.mkdir(parents=True, exist_ok=True) - - settings = { - 'model_name': model_name, - 'dtype': dtype, - 'attn_backend': attn_backend, - 'cpu_offload': cpu_offload, - 'compile_model': compile_model, - } - - try: - with open(settings_path, 'w') as f: - yaml.dump(settings, f, default_flow_style=False) - logger.info(f"Saved image model settings to {settings_path}") - except Exception as e: - logger.error(f"Failed to save image model settings: {e}") diff --git a/modules/shared.py b/modules/shared.py index e54eca8f..9a062e91 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -303,6 +303,22 @@ settings = { # Extensions 'default_extensions': [], + + # Image generation settings + 'image_prompt': '', + 'image_neg_prompt': '', + 'image_width': 1024, + 'image_height': 1024, + 'image_aspect_ratio': '1:1 Square', + 'image_steps': 9, + 'image_seed': -1, + 'image_batch_size': 1, + 'image_batch_count': 1, + 'image_model_menu': 'None', + 'image_dtype': 'bfloat16', + 'image_attn_backend': 'sdpa', + 'image_compile': False, + 'image_cpu_offload': False, } default_settings = copy.deepcopy(settings) @@ -327,6 +343,20 @@ def do_cmd_flags_warnings(): logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.') +def apply_image_model_cli_overrides(): + """Apply CLI flags for image model settings, overriding saved settings.""" + if args.image_model: + settings['image_model_menu'] = args.image_model + if args.image_dtype is not None: + settings['image_dtype'] = args.image_dtype + if args.image_attn_backend is not None: + settings['image_attn_backend'] = args.image_attn_backend + if args.image_cpu_offload: + settings['image_cpu_offload'] = True + if args.image_compile: + settings['image_compile'] = True + + def fix_loader_name(name): if not name: return name diff --git a/modules/ui.py b/modules/ui.py index f99e8b6a..3aba20b4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -280,6 +280,24 @@ def list_interface_input_elements(): 'include_past_attachments', ] + # Image generation elements + elements += [ + 'image_prompt', + 'image_neg_prompt', + 'image_width', + 'image_height', + 'image_aspect_ratio', + 'image_steps', + 'image_seed', + 'image_batch_size', + 'image_batch_count', + 'image_model_menu', + 'image_dtype', + 'image_attn_backend', + 'image_compile', + 'image_cpu_offload', + ] + return elements @@ -509,7 +527,21 @@ def setup_auto_save(): 'theme_state', 'show_two_notebook_columns', 'paste_to_attachment', - 'include_past_attachments' + 'include_past_attachments', + + # Image generation tab (ui_image_generation.py) + 'image_width', + 'image_height', + 'image_aspect_ratio', + 'image_steps', + 'image_seed', + 'image_batch_size', + 'image_batch_count', + 'image_model_menu', + 'image_dtype', + 'image_attn_backend', + 'image_compile', + 'image_cpu_offload', ] for element_name in change_elements: diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index 038e96c8..fe0c8120 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -7,14 +7,10 @@ import gradio as gr import numpy as np import torch -from modules import shared, utils -from modules.image_model_settings import ( - get_effective_settings, - save_image_model_settings -) +from modules import shared, ui, utils from modules.image_models import load_image_model, unload_image_model +from modules.utils import gradio -# Aspect ratio definitions: name -> (width_ratio, height_ratio) ASPECT_RATIOS = { "1:1 Square": (1, 1), "16:9 Cinema": (16, 9), @@ -23,50 +19,34 @@ ASPECT_RATIOS = { "Custom": None, } -STEP = 32 # Slider step for rounding +STEP = 32 def round_to_step(value, step=STEP): - """Round a value to the nearest step.""" return round(value / step) * step def clamp(value, min_val, max_val): - """Clamp value between min and max.""" return max(min_val, min(max_val, value)) def apply_aspect_ratio(aspect_ratio, current_width, current_height): - """ - Apply an aspect ratio preset. - - Logic to prevent dimension creep: - - For tall ratios (like 9:16): keep width fixed, calculate height - - For wide ratios (like 16:9): keep height fixed, calculate width - - For square (1:1): use the smaller of the current dimensions - - Returns (new_width, new_height). - """ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return current_width, current_height w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] if w_ratio == h_ratio: - # Square ratio - use the smaller current dimension to prevent creep base = min(current_width, current_height) new_width = base new_height = base elif w_ratio < h_ratio: - # Tall ratio (like 9:16) - width is the smaller side, keep it fixed new_width = current_width new_height = round_to_step(current_width * h_ratio / w_ratio) else: - # Wide ratio (like 16:9) - height is the smaller side, keep it fixed new_height = current_height new_width = round_to_step(current_height * w_ratio / h_ratio) - # Clamp to slider bounds new_width = clamp(new_width, 256, 2048) new_height = clamp(new_height, 256, 2048) @@ -74,7 +54,6 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height): def update_height_from_width(width, aspect_ratio): - """Update height when width changes (if not Custom).""" if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return gr.update() @@ -86,7 +65,6 @@ def update_height_from_width(width, aspect_ratio): def update_width_from_height(height, aspect_ratio): - """Update width when height changes (if not Custom).""" if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return gr.update() @@ -98,16 +76,13 @@ def update_width_from_height(height, aspect_ratio): def swap_dimensions_and_update_ratio(width, height, aspect_ratio): - """Swap dimensions and update aspect ratio to match (or set to Custom).""" new_width, new_height = height, width - # Try to find a matching aspect ratio for the swapped dimensions new_ratio = "Custom" for name, ratios in ASPECT_RATIOS.items(): if ratios is None: continue w_r, h_r = ratios - # Check if the swapped dimensions match this ratio (within tolerance) expected_height = new_width * h_r / w_r if abs(expected_height - new_height) < STEP: new_ratio = name @@ -117,291 +92,257 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio): def create_ui(): - # Get effective settings (CLI > yaml > defaults) - settings = get_effective_settings() - - # Update shared state (but don't load the model yet) - if settings['model_name'] != 'None': - shared.image_model_name = settings['model_name'] + if shared.settings['image_model_menu'] != 'None': + shared.image_model_name = shared.settings['image_model_menu'] with gr.Tab("Image AI", elem_id="image-ai-tab"): with gr.Tabs(): - # TAB 1: GENERATION STUDIO + # TAB 1: GENERATE with gr.TabItem("Generate"): with gr.Row(): - - # === LEFT COLUMN: CONTROLS === with gr.Column(scale=4, min_width=350): + shared.gradio['image_prompt'] = gr.Textbox( + label="Prompt", + placeholder="Describe your imagination...", + lines=3, + autofocus=True, + value=shared.settings['image_prompt'] + ) + shared.gradio['image_neg_prompt'] = gr.Textbox( + label="Negative Prompt", + placeholder="Low quality...", + lines=3, + value=shared.settings['image_neg_prompt'] + ) - # 1. PROMPT - prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True) - neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3) - - # 2. GENERATE BUTTON - generate_btn = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn") + shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn") gr.HTML("
") - # 3. DIMENSIONS gr.Markdown("### 📐 Dimensions") with gr.Row(): with gr.Column(): - width_slider = gr.Slider(256, 2048, value=1024, step=32, label="Width") - + shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width") with gr.Column(): - height_slider = gr.Slider(256, 2048, value=1024, step=32, label="Height") - - swap_btn = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width") + shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height") + shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width") with gr.Row(): - preset_radio = gr.Radio( + shared.gradio['image_aspect_ratio'] = gr.Radio( choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"], - value="1:1 Square", + value=shared.settings['image_aspect_ratio'], label="Aspect Ratio", interactive=True ) - # 4. SETTINGS & BATCHING gr.Markdown("### ⚙️ Config") with gr.Row(): with gr.Column(): - steps_slider = gr.Slider(1, 15, value=9, step=1, label="Steps") - seed_input = gr.Number(label="Seed", value=-1, precision=0, info="-1 = Random") - + shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps") + shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random") with gr.Column(): - batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.") - batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.") + shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.") + shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.") - # === RIGHT COLUMN: VIEWPORT === with gr.Column(scale=6, min_width=500): with gr.Column(elem_classes=["viewport-container"]): - output_gallery = gr.Gallery( - label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True - ) + shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True) with gr.Row(): - used_seed = gr.Markdown(label="Info", interactive=False) + shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False) - # TAB 2: HISTORY VIEWER + # TAB 2: GALLERY with gr.TabItem("Gallery"): with gr.Row(): - refresh_btn = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button") + shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button") + shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True) - history_gallery = gr.Gallery( - label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True - ) - - # TAB 3: MODEL SETTINGS + # TAB 3: MODEL with gr.TabItem("Model"): with gr.Row(): with gr.Column(): with gr.Row(): - image_model_menu = gr.Dropdown( + shared.gradio['image_model_menu'] = gr.Dropdown( choices=utils.get_available_image_models(), - value=settings['model_name'], + value=shared.settings['image_model_menu'], label='Model', elem_classes='slim-dropdown' ) - image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40) - image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button') - image_unload_model = gr.Button("Unload", elem_classes='refresh-button') + shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40) + shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button') + shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button') gr.Markdown("## Settings") - with gr.Row(): with gr.Column(): - image_dtype = gr.Dropdown( + shared.gradio['image_dtype'] = gr.Dropdown( choices=['bfloat16', 'float16'], - value=settings['dtype'], + value=shared.settings['image_dtype'], label='Data Type', info='bfloat16 recommended for modern GPUs' ) - - image_attn_backend = gr.Dropdown( + shared.gradio['image_attn_backend'] = gr.Dropdown( choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], - value=settings['attn_backend'], + value=shared.settings['image_attn_backend'], label='Attention Backend', info='SDPA is default. Flash Attention requires compatible GPU.' ) - with gr.Column(): - image_compile = gr.Checkbox( - value=settings['compile_model'], + shared.gradio['image_compile'] = gr.Checkbox( + value=shared.settings['image_compile'], label='Compile Model', info='Faster inference after first run. First run will be slow.' ) - - image_cpu_offload = gr.Checkbox( - value=settings['cpu_offload'], + shared.gradio['image_cpu_offload'] = gr.Checkbox( + value=shared.settings['image_cpu_offload'], label='CPU Offload', info='Enable for low VRAM GPUs. Slower but uses less memory.' ) with gr.Column(): - image_download_path = gr.Textbox( + shared.gradio['image_download_path'] = gr.Textbox( label="Download model", placeholder="Tongyi-MAI/Z-Image-Turbo", - info="Enter the HuggingFace model path like Tongyi-MAI/Z-Image-Turbo. Use : for branch, e.g. Tongyi-MAI/Z-Image-Turbo:main" + info="Enter HuggingFace path. Use : for branch, e.g. user/model:main" ) - image_download_btn = gr.Button("Download", variant='primary') - image_model_status = gr.Markdown( - value=f"Model: **{settings['model_name']}** (not loaded)" if settings['model_name'] != 'None' else "No model selected" + shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary') + shared.gradio['image_model_status'] = gr.Markdown( + value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected" ) - # === WIRING === - # Aspect ratio preset changes -> update dimensions - preset_radio.change( - fn=apply_aspect_ratio, - inputs=[preset_radio, width_slider, height_slider], - outputs=[width_slider, height_slider], - show_progress=False - ) +def create_event_handlers(): + # Dimension controls + shared.gradio['image_aspect_ratio'].change( + apply_aspect_ratio, + gradio('image_aspect_ratio', 'image_width', 'image_height'), + gradio('image_width', 'image_height'), + show_progress=False + ) - # Width slider changes -> update height (if not Custom) - width_slider.release( - fn=update_height_from_width, - inputs=[width_slider, preset_radio], - outputs=[height_slider], - show_progress=False - ) + shared.gradio['image_width'].release( + update_height_from_width, + gradio('image_width', 'image_aspect_ratio'), + gradio('image_height'), + show_progress=False + ) - # Height slider changes -> update width (if not Custom) - height_slider.release( - fn=update_width_from_height, - inputs=[height_slider, preset_radio], - outputs=[width_slider], - show_progress=False - ) + shared.gradio['image_height'].release( + update_width_from_height, + gradio('image_height', 'image_aspect_ratio'), + gradio('image_width'), + show_progress=False + ) - # Swap button -> swap dimensions and update aspect ratio - swap_btn.click( - fn=swap_dimensions_and_update_ratio, - inputs=[width_slider, height_slider, preset_radio], - outputs=[width_slider, height_slider, preset_radio], - show_progress=False - ) + shared.gradio['image_swap_btn'].click( + swap_dimensions_and_update_ratio, + gradio('image_width', 'image_height', 'image_aspect_ratio'), + gradio('image_width', 'image_height', 'image_aspect_ratio'), + show_progress=False + ) - # Generation - inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq] - outputs = [output_gallery, used_seed] + # Generation + shared.gradio['image_generate_btn'].click( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed')) - generate_btn.click( - fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile), - inputs=inputs, - outputs=outputs - ) - prompt.submit( - fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile), - inputs=inputs, - outputs=outputs - ) - neg_prompt.submit( - fn=lambda *args: generate(*args, image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile), - inputs=inputs, - outputs=outputs - ) + shared.gradio['image_prompt'].submit( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed')) - # Model tab events - image_refresh_models.click( - fn=lambda: gr.update(choices=utils.get_available_image_models()), - inputs=None, - outputs=[image_model_menu], - show_progress=False - ) + shared.gradio['image_neg_prompt'].submit( + ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( + generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed')) - image_load_model.click( - fn=load_image_model_wrapper, - inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile], - outputs=[image_model_status], - show_progress=True - ) + # Model management + shared.gradio['image_refresh_models'].click( + lambda: gr.update(choices=utils.get_available_image_models()), + None, + gradio('image_model_menu'), + show_progress=False + ) - image_unload_model.click( - fn=unload_image_model_wrapper, - inputs=None, - outputs=[image_model_status], - show_progress=False - ) + shared.gradio['image_load_model'].click( + load_image_model_wrapper, + gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'), + gradio('image_model_status'), + show_progress=True + ) - image_download_btn.click( - fn=download_image_model_wrapper, - inputs=[image_download_path], - outputs=[image_model_status, image_model_menu], - show_progress=True - ) + shared.gradio['image_unload_model'].click( + unload_image_model_wrapper, + None, + gradio('image_model_status'), + show_progress=False + ) - # History - refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False) + shared.gradio['image_download_btn'].click( + download_image_model_wrapper, + gradio('image_download_path'), + gradio('image_model_status', 'image_model_menu'), + show_progress=True + ) + + # History + shared.gradio['image_refresh_history'].click( + get_history_images, + None, + gradio('image_history_gallery'), + show_progress=False + ) -def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq, - model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox): - """Generate images with the current model settings.""" +def generate(state): + model_name = state['image_model_menu'] - model_name = shared.image_model_name - - if model_name == 'None': + if not model_name or model_name == 'None': return [], "No image model selected. Go to the Model tab and select a model." - # Auto-load model if not loaded if shared.image_model is None: - # Get effective settings (CLI > yaml > defaults) - settings = get_effective_settings() - result = load_image_model( model_name, - dtype=settings['dtype'], - attn_backend=settings['attn_backend'], - cpu_offload=settings['cpu_offload'], - compile_model=settings['compile_model'] + dtype=state['image_dtype'], + attn_backend=state['image_attn_backend'], + cpu_offload=state['image_cpu_offload'], + compile_model=state['image_compile'] ) - if result is None: return [], f"Failed to load model `{model_name}`." + shared.image_model_name = model_name + + seed = state['image_seed'] if seed == -1: seed = np.random.randint(0, 2**32 - 1) generator = torch.Generator("cuda").manual_seed(int(seed)) all_images = [] - # Sequential loop (easier on VRAM) - for i in range(int(batch_count_seq)): - current_seed = seed + i - generator.manual_seed(int(current_seed)) - - # Parallel generation + for i in range(int(state['image_batch_count'])): + generator.manual_seed(int(seed + i)) batch_results = shared.image_model( - prompt=prompt, - negative_prompt=neg_prompt, - height=int(height), - width=int(width), - num_inference_steps=int(steps), + prompt=state['image_prompt'], + negative_prompt=state['image_neg_prompt'], + height=int(state['image_height']), + width=int(state['image_width']), + num_inference_steps=int(state['image_steps']), guidance_scale=0.0, - num_images_per_prompt=int(batch_size_parallel), + num_images_per_prompt=int(state['image_batch_size']), generator=generator, ).images - all_images.extend(batch_results) - # Save to disk - save_generated_images(all_images, prompt, seed) - + save_generated_images(all_images, state['image_prompt'], seed) return all_images, f"Seed: {seed}" def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model): - """Load model and save settings.""" - if model_name == 'None' or not model_name: + if not model_name or model_name == 'None': yield "No model selected" return try: yield f"Loading `{model_name}`..." - - # Unload existing model first unload_image_model() - # Load the new model result = load_image_model( model_name, dtype=dtype, @@ -411,29 +352,22 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi ) if result is not None: - # Save settings to yaml - save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model) + shared.image_model_name = model_name yield f"✓ Loaded **{model_name}**" else: yield f"✗ Failed to load `{model_name}`" - except Exception: - exc = traceback.format_exc() - yield f"Error:\n```\n{exc}\n```" + yield f"Error:\n```\n{traceback.format_exc()}\n```" def unload_image_model_wrapper(): - """Unload model wrapper.""" unload_image_model() - if shared.image_model_name != 'None': return f"Model: **{shared.image_model_name}** (not loaded)" - else: - return "No model loaded" + return "No model loaded" def download_image_model_wrapper(model_path): - """Download a model from Hugging Face.""" from huggingface_hub import snapshot_download if not model_path: @@ -441,17 +375,15 @@ def download_image_model_wrapper(model_path): return try: - # Parse model name and branch if ':' in model_path: model_id, branch = model_path.rsplit(':', 1) else: model_id, branch = model_path, 'main' - # Output folder name (username_model format) folder_name = model_id.replace('/', '_') output_folder = Path(shared.args.image_model_dir) / folder_name - yield f"Downloading `{model_id}` (branch: {branch}) to `{output_folder}`...", gr.update() + yield f"Downloading `{model_id}` (branch: {branch})...", gr.update() snapshot_download( repo_id=model_id, @@ -460,48 +392,34 @@ def download_image_model_wrapper(model_path): local_dir_use_symlinks=False, ) - # Refresh the model list new_choices = utils.get_available_image_models() - yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name) - except Exception: - exc = traceback.format_exc() - yield f"Error:\n```\n{exc}\n```", gr.update() + yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update() def save_generated_images(images, prompt, seed): - """Save generated images to disk.""" date_str = datetime.now().strftime("%Y-%m-%d") folder_path = os.path.join("user_data", "image_outputs", date_str) os.makedirs(folder_path, exist_ok=True) - saved_paths = [] - for idx, img in enumerate(images): timestamp = datetime.now().strftime("%H-%M-%S") filename = f"{timestamp}_{seed}_{idx}.png" - full_path = os.path.join(folder_path, filename) - - img.save(full_path) - saved_paths.append(full_path) - - return saved_paths + img.save(os.path.join(folder_path, filename)) def get_history_images(): - """Scan the outputs folder and return all images, newest first.""" output_dir = os.path.join("user_data", "image_outputs") if not os.path.exists(output_dir): return [] image_files = [] - for root, dirs, files in os.walk(output_dir): + for root, _, files in os.walk(output_dir): for file in files: if file.endswith((".png", ".jpg", ".jpeg")): full_path = os.path.join(root, file) - mtime = os.path.getmtime(full_path) - image_files.append((full_path, mtime)) + image_files.append((full_path, os.path.getmtime(full_path))) image_files.sort(key=lambda x: x[1], reverse=True) return [x[0] for x in image_files] diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index a5e0f640..86adc229 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -434,66 +434,3 @@ def format_file_size(size_bytes): return f"{s:.2f} {size_names[i]}" else: return f"{s:.1f} {size_names[i]}" - - -def load_image_model_wrapper(selected_model): - """Wrapper for loading image models with status updates.""" - from modules.image_models import load_image_model, unload_image_model - - if selected_model == 'None' or not selected_model: - yield "No model selected" - return - - try: - yield f"Loading `{selected_model}`..." - unload_image_model() - result = load_image_model(selected_model) - - if result is not None: - yield f"Successfully loaded `{selected_model}`." - else: - yield f"Failed to load `{selected_model}`." - except Exception: - exc = traceback.format_exc() - yield exc.replace('\n', '\n\n') - - -def handle_unload_image_model_click(): - """Handler for the image model unload button.""" - from modules.image_models import unload_image_model - unload_image_model() - return "Image model unloaded" - - -def download_image_model_wrapper(custom_model): - """Download an image model from Hugging Face.""" - from huggingface_hub import snapshot_download - - if not custom_model: - yield "No model specified" - return - - try: - # Parse model name and branch - if ':' in custom_model: - model_name, branch = custom_model.rsplit(':', 1) - else: - model_name, branch = custom_model, 'main' - - # Output folder - output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1] - - yield f"Downloading `{model_name}` (branch: {branch})..." - - snapshot_download( - repo_id=model_name, - revision=branch, - local_dir=output_folder, - local_dir_use_symlinks=False, - ) - - yield f"Model successfully saved to `{output_folder}/`." - - except Exception: - exc = traceback.format_exc() - yield exc.replace('\n', '\n\n') diff --git a/server.py b/server.py index 87bbdc4a..5a75e887 100644 --- a/server.py +++ b/server.py @@ -172,6 +172,7 @@ def create_interface(): ui_chat.create_event_handlers() ui_default.create_event_handlers() ui_notebook.create_event_handlers() + ui_image_generation.create_event_handlers() # Other events ui_file_saving.create_event_handlers() @@ -258,6 +259,9 @@ if __name__ == "__main__": if new_settings: shared.settings.update(new_settings) + # Apply CLI overrides for image model settings (CLI flags take precedence over saved settings) + shared.apply_image_model_cli_overrides() + # Fallback settings for models shared.model_config['.*'] = get_fallback_settings() shared.model_config.move_to_end('.*', last=False) # Move to the beginning