diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index 030379fd..749ca981 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -38,19 +38,19 @@ def clamp(value, min_val, max_val): def apply_aspect_ratio(aspect_ratio, current_width, current_height): """ Apply an aspect ratio preset. - + Logic to prevent dimension creep: - For tall ratios (like 9:16): keep width fixed, calculate height - - For wide ratios (like 16:9): keep height fixed, calculate width + - For wide ratios (like 16:9): keep height fixed, calculate width - For square (1:1): use the smaller of the current dimensions - + Returns (new_width, new_height). """ if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return current_width, current_height - + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] - + if w_ratio == h_ratio: # Square ratio - use the smaller current dimension to prevent creep base = min(current_width, current_height) @@ -64,11 +64,11 @@ def apply_aspect_ratio(aspect_ratio, current_width, current_height): # Wide ratio (like 16:9) - height is the smaller side, keep it fixed new_height = current_height new_width = round_to_step(current_height * w_ratio / h_ratio) - + # Clamp to slider bounds new_width = clamp(new_width, 256, 2048) new_height = clamp(new_height, 256, 2048) - + return int(new_width), int(new_height) @@ -76,11 +76,11 @@ def update_height_from_width(width, aspect_ratio): """Update height when width changes (if not Custom).""" if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return gr.update() - + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] new_height = round_to_step(width * h_ratio / w_ratio) new_height = clamp(new_height, 256, 2048) - + return int(new_height) @@ -88,18 +88,18 @@ def update_width_from_height(height, aspect_ratio): """Update width when height changes (if not Custom).""" if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return gr.update() - + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] new_width = round_to_step(height * w_ratio / h_ratio) new_width = clamp(new_width, 256, 2048) - + return int(new_width) def swap_dimensions_and_update_ratio(width, height, aspect_ratio): """Swap dimensions and update aspect ratio to match (or set to Custom).""" new_width, new_height = height, width - + # Try to find a matching aspect ratio for the swapped dimensions new_ratio = "Custom" for name, ratios in ASPECT_RATIOS.items(): @@ -111,27 +111,27 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio): if abs(expected_height - new_height) < STEP: new_ratio = name break - + return new_width, new_height, new_ratio def create_ui(): # Get effective settings (CLI > yaml > defaults) settings = get_effective_settings() - + # Update shared state (but don't load the model yet) if settings['model_name'] != 'None': shared.image_model_name = settings['model_name'] - + with gr.Tab("Image AI", elem_id="image-ai-tab"): with gr.Tabs(): # TAB 1: GENERATION STUDIO with gr.TabItem("Generate"): with gr.Row(): - + # === LEFT COLUMN: CONTROLS === with gr.Column(scale=4, min_width=350): - + # 1. PROMPT prompt = gr.Textbox(label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True) neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="Low quality...", lines=3) @@ -170,12 +170,12 @@ def create_ui(): with gr.Column(): batch_size_parallel = gr.Slider(1, 32, value=1, step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.") batch_count_seq = gr.Slider(1, 128, value=1, step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.") - + # === RIGHT COLUMN: VIEWPORT === with gr.Column(scale=6, min_width=500): with gr.Column(elem_classes=["viewport-container"]): output_gallery = gr.Gallery( - label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True + label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True ) with gr.Row(): used_seed = gr.Markdown(label="Info", interactive=False) @@ -203,9 +203,9 @@ def create_ui(): image_refresh_models = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40) image_load_model = gr.Button("Load", variant='primary', elem_classes='refresh-button') image_unload_model = gr.Button("Unload", elem_classes='refresh-button') - + gr.Markdown("## Settings") - + with gr.Row(): with gr.Column(): image_dtype = gr.Dropdown( @@ -214,14 +214,14 @@ def create_ui(): label='Data Type', info='bfloat16 recommended for modern GPUs' ) - + image_attn_backend = gr.Dropdown( choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], value=settings['attn_backend'], label='Attention Backend', info='SDPA is default. Flash Attention requires compatible GPU.' ) - + with gr.Column(): image_compile = gr.Checkbox( value=settings['compile_model'], @@ -234,7 +234,7 @@ def create_ui(): label='CPU Offload', info='Enable for low VRAM GPUs. Slower but uses less memory.' ) - + with gr.Column(): image_download_path = gr.Textbox( label="Download model", @@ -247,7 +247,7 @@ def create_ui(): ) # === WIRING === - + # Aspect ratio preset changes -> update dimensions preset_radio.change( fn=apply_aspect_ratio, @@ -255,7 +255,7 @@ def create_ui(): outputs=[width_slider, height_slider], show_progress=False ) - + # Width slider changes -> update height (if not Custom) width_slider.release( fn=update_height_from_width, @@ -263,7 +263,7 @@ def create_ui(): outputs=[height_slider], show_progress=False ) - + # Height slider changes -> update width (if not Custom) height_slider.release( fn=update_width_from_height, @@ -271,7 +271,7 @@ def create_ui(): outputs=[width_slider], show_progress=False ) - + # Swap button -> swap dimensions and update aspect ratio swap_btn.click( fn=swap_dimensions_and_update_ratio, @@ -299,7 +299,7 @@ def create_ui(): inputs=inputs, outputs=outputs ) - + # Model tab events image_refresh_models.click( fn=lambda: gr.update(choices=utils.get_available_image_models()), @@ -307,28 +307,28 @@ def create_ui(): outputs=[image_model_menu], show_progress=False ) - + image_load_model.click( fn=load_image_model_wrapper, inputs=[image_model_menu, image_dtype, image_attn_backend, image_cpu_offload, image_compile], outputs=[image_model_status], show_progress=True ) - + image_unload_model.click( fn=unload_image_model_wrapper, inputs=None, outputs=[image_model_status], show_progress=False ) - + image_download_btn.click( fn=download_image_model_wrapper, inputs=[image_download_path], outputs=[image_model_status, image_model_menu], show_progress=True ) - + # History refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery, show_progress=False) @@ -336,40 +336,39 @@ def create_ui(): def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq, model_menu, dtype_dropdown, attn_dropdown, cpu_offload_checkbox, compile_checkbox): """Generate images with the current model settings.""" - - # Get current UI values (these are Gradio components, we need their values) + model_name = shared.image_model_name - + if model_name == 'None': return [], "No image model selected. Go to the Model tab and select a model." - + # Auto-load model if not loaded if shared.image_model is None: - # Load saved settings for the model - saved_settings = load_image_model_settings() - + # Get effective settings (CLI > yaml > defaults) + settings = get_effective_settings() + result = load_image_model( model_name, - dtype=saved_settings['dtype'], - attn_backend=saved_settings['attn_backend'], - cpu_offload=saved_settings['cpu_offload'], - compile_model=saved_settings['compile_model'] + dtype=settings['dtype'], + attn_backend=settings['attn_backend'], + cpu_offload=settings['cpu_offload'], + compile_model=settings['compile_model'] ) - + if result is None: return [], f"Failed to load model `{model_name}`." - + if seed == -1: seed = np.random.randint(0, 2**32 - 1) - + generator = torch.Generator("cuda").manual_seed(int(seed)) all_images = [] - + # Sequential loop (easier on VRAM) for i in range(int(batch_count_seq)): current_seed = seed + i generator.manual_seed(int(current_seed)) - + # Parallel generation batch_results = shared.image_model( prompt=prompt, @@ -381,12 +380,12 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel num_images_per_prompt=int(batch_size_parallel), generator=generator, ).images - + all_images.extend(batch_results) - + # Save to disk save_generated_images(all_images, prompt, seed) - + return all_images, f"Seed: {seed}" @@ -395,13 +394,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi if model_name == 'None' or not model_name: yield "No model selected" return - + try: yield f"Loading `{model_name}`..." - + # Unload existing model first unload_image_model() - + # Load the new model result = load_image_model( model_name, @@ -410,14 +409,14 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi cpu_offload=cpu_offload, compile_model=compile_model ) - + if result is not None: # Save settings to yaml save_image_model_settings(model_name, dtype, attn_backend, cpu_offload, compile_model) yield f"✓ Loaded **{model_name}**" else: yield f"✗ Failed to load `{model_name}`" - + except Exception: exc = traceback.format_exc() yield f"Error:\n```\n{exc}\n```" @@ -426,7 +425,7 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi def unload_image_model_wrapper(): """Unload model wrapper.""" unload_image_model() - + if shared.image_model_name != 'None': return f"Model: **{shared.image_model_name}** (not loaded)" else: @@ -436,36 +435,36 @@ def unload_image_model_wrapper(): def download_image_model_wrapper(model_path): """Download a model from Hugging Face.""" from huggingface_hub import snapshot_download - + if not model_path: yield "No model specified", gr.update() return - + try: # Parse model name and branch if ':' in model_path: model_id, branch = model_path.rsplit(':', 1) else: model_id, branch = model_path, 'main' - + # Output folder name folder_name = model_id.split('/')[-1] output_folder = Path(shared.args.image_model_dir) / folder_name - + yield f"Downloading `{model_id}` (branch: {branch})...", gr.update() - + snapshot_download( repo_id=model_id, revision=branch, local_dir=output_folder, local_dir_use_symlinks=False, ) - + # Refresh the model list new_choices = utils.get_available_image_models() - + yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name) - + except Exception: exc = traceback.format_exc() yield f"Error:\n```\n{exc}\n```", gr.update()