diff --git a/modules/shared.py b/modules/shared.py
index fda4ece6..66666b75 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -51,8 +51,12 @@ group.add_argument('--verbose', action='store_true', help='Print the prompts to
 group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
 
 # Image generation
+group = parser.add_argument_group('Image model')
 group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
 group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
+group.add_argument('--image-dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], help='Data type for image model.')
+group.add_argument('--image-attn-backend', type=str, default='sdpa', choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
+group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
 
 # Model loader
 group = parser.add_argument_group('Model loader')
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index b59a8458..25bfeb21 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,6 +1,13 @@
-import gradio as gr
 import os
-from modules.utils import resolve_model_path
+from datetime import datetime
+
+import gradio as gr
+import numpy as np
+import torch
+
+from modules import shared
+from modules.image_models import load_image_model, unload_image_model
+
 
 def create_ui():
     with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -92,26 +99,33 @@ def create_ui():
 
 def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
-    if engine.pipe is None:
-        load_pipeline("SDPA", False, False, "bfloat16")
-
-    if seed == -1: seed = np.random.randint(0, 2**32 - 1)
-
-    # We use a base generator. For sequential batches, we might increment seed if desired,
-    # but here we keep the base seed logic consistent.
+    import numpy as np
+    import torch
+    from modules import shared
+    from modules.image_models import load_image_model
+
+    # Auto-load model if not loaded
+    if shared.image_model is None:
+        if shared.image_model_name == 'None':
+            return [], "No image model selected. Please load a model first."
+        load_image_model(shared.image_model_name)
+
+        if shared.image_model is None:
+            return [], "Failed to load image model."
+
+    if seed == -1:
+        seed = np.random.randint(0, 2**32 - 1)
+
     generator = torch.Generator("cuda").manual_seed(int(seed))
-    all_images = []
-
-    # SEQUENTIAL LOOP (Easy on VRAM)
-    for i in range(batch_count_seq):
-        # Update seed for subsequent batches so they aren't identical
+    all_images = []
+
+    # Sequential loop (easier on VRAM)
+    for i in range(int(batch_count_seq)):
         current_seed = seed + i
         generator.manual_seed(int(current_seed))
-
-        # PARALLEL GENERATION (Fast, Heavy VRAM)
-        # diffusers handles 'num_images_per_prompt' for parallel execution
-        batch_results = engine.pipe(
+
+        # Parallel generation
+        batch_results = shared.image_model(
             prompt=prompt,
             negative_prompt=neg_prompt,
             height=int(height),
@@ -121,13 +135,13 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
             num_images_per_prompt=int(batch_size_parallel),
             generator=generator,
         ).images
-
+
         all_images.extend(batch_results)
-
+
     # Save to disk
     save_generated_images(all_images, prompt, seed)
-
-    return all_images, seed
+
+    return all_images, f"Seed: {seed}"
 
 
 # --- File Saving Logic ---
@@ -173,38 +187,3 @@ def get_history_images():
     # Sort by time, newest first
     image_files.sort(key=lambda x: x[1], reverse=True)
     return [x[0] for x in image_files]
-
-
-def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
-    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
-    target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
-
-    if engine.pipe is not None and engine.config["backend"] == attn_backend:
-        return gr.Info("Pipeline ready.")
-
-    try:
-        gr.Info(f"Loading Model ({attn_backend})...")
-        pipe = ZImagePipeline.from_pretrained(
-            engine.config["model_id"],
-            torch_dtype=target_dtype,
-            low_cpu_mem_usage=False,
-        )
-        if not offload_cpu: pipe.to("cuda")
-
-        if attn_backend == "Flash Attention 2":
-            pipe.transformer.set_attention_backend("flash")
-        elif attn_backend == "Flash Attention 3":
-            pipe.transformer.set_attention_backend("_flash_3")
-
-        if compile_model:
-            gr.Warning("Compiling... First run will be slow.")
-            pipe.transformer.compile()
-
-        if offload_cpu: pipe.enable_model_cpu_offload()
-
-        engine.pipe = pipe
-        engine.config["backend"] = attn_backend
-        return gr.Success("System Ready.")
-    except Exception as e:
-        return gr.Error(f"Init Failed: {str(e)}")
-
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index dbbd3274..cb3508f8 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -140,7 +140,7 @@ def create_ui():
         with gr.Column():
             with gr.Row():
                 shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
-                ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
+                ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
                 shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
                 shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
                 shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
@@ -169,7 +169,7 @@ def create_ui():
                 shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
 
         with gr.Row():
-            shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+            shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
 
 
 def create_event_handlers():
@@ -220,6 +220,28 @@ def create_event_handlers():
     shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
     shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
 
+    # Image model event handlers
+    shared.gradio['image_load_model'].click(
+        load_image_model_wrapper,
+        gradio('image_model_menu'),
+        gradio('image_model_status'),
+        show_progress=True
+    )
+
+    shared.gradio['image_unload_model'].click(
+        handle_unload_image_model_click,
+        None,
+        gradio('image_model_status'),
+        show_progress=False
+    )
+
+    shared.gradio['image_download_model_button'].click(
+        download_image_model_wrapper,
+        gradio('image_custom_model_menu'),
+        gradio('image_model_status'),
+        show_progress=True
+    )
+
 
 def load_model_wrapper(selected_model, loader, autoload=False):
     try:
@@ -471,3 +493,66 @@ def format_file_size(size_bytes):
         return f"{s:.2f} {size_names[i]}"
     else:
         return f"{s:.1f} {size_names[i]}"
+
+
+def load_image_model_wrapper(selected_model):
+    """Wrapper for loading image models with status updates."""
+    from modules.image_models import load_image_model, unload_image_model
+
+    if selected_model == 'None' or not selected_model:
+        yield "No model selected"
+        return
+
+    try:
+        yield f"Loading `{selected_model}`..."
+        unload_image_model()
+        result = load_image_model(selected_model)
+
+        if result is not None:
+            yield f"Successfully loaded `{selected_model}`."
+        else:
+            yield f"Failed to load `{selected_model}`."
+    except Exception:
+        exc = traceback.format_exc()
+        yield exc.replace('\n', '\n\n')
+
+
+def handle_unload_image_model_click():
+    """Handler for the image model unload button."""
+    from modules.image_models import unload_image_model
+    unload_image_model()
+    return "Image model unloaded"
+
+
+def download_image_model_wrapper(custom_model):
+    """Download an image model from Hugging Face."""
+    from huggingface_hub import snapshot_download
+
+    if not custom_model:
+        yield "No model specified"
+        return
+
+    try:
+        # Parse model name and branch
+        if ':' in custom_model:
+            model_name, branch = custom_model.rsplit(':', 1)
+        else:
+            model_name, branch = custom_model, 'main'
+
+        # Output folder
+        output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
+
+        yield f"Downloading `{model_name}` (branch: {branch})..."
+
+        snapshot_download(
+            repo_id=model_name,
+            revision=branch,
+            local_dir=output_folder,
+            local_dir_use_symlinks=False,
+        )
+
+        yield f"Model successfully saved to `{output_folder}/`."
+
+    except Exception:
+        exc = traceback.format_exc()
+        yield exc.replace('\n', '\n\n')
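Note: modules/image_models.py, which provides load_image_model(), unload_image_model(), and the shared.image_model / shared.image_model_name state used above, is imported throughout this diff but is not part of it. The sketch below is only an assumption of what those two helpers might look like, inferred from the new --image-dtype / --image-attn-backend / --image-cpu-offload flags and from the removed load_pipeline() code; the actual module may differ.

# modules/image_models.py -- hypothetical sketch, not part of this diff
import gc
from pathlib import Path

import torch
from diffusers import DiffusionPipeline

from modules import shared


def load_image_model(model_name):
    # Resolve the model folder under --image-model-dir and pick the dtype from --image-dtype.
    model_path = Path(shared.args.image_model_dir) / model_name
    dtype = torch.bfloat16 if shared.args.image_dtype == 'bfloat16' else torch.float16

    pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)

    # Map --image-attn-backend the same way the removed load_pipeline() did
    # ('sdpa' is the diffusers default, so it needs no explicit call). This assumes
    # a transformer-based pipeline; a UNet pipeline would configure pipe.unet instead.
    if shared.args.image_attn_backend == 'flash_attention_2':
        pipe.transformer.set_attention_backend("flash")
    elif shared.args.image_attn_backend == 'flash_attention_3':
        pipe.transformer.set_attention_backend("_flash_3")

    # --image-cpu-offload trades speed for VRAM, mirroring the old offload_cpu option.
    if shared.args.image_cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cuda")

    shared.image_model = pipe
    shared.image_model_name = model_name
    return pipe


def unload_image_model():
    # Drop the pipeline and release VRAM so the text model can use it again.
    shared.image_model = None
    gc.collect()
    torch.cuda.empty_cache()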