diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
index 5b5c624d..b59a8458 100644
--- a/modules/ui_image_generation.py
+++ b/modules/ui_image_generation.py
@@ -1,6 +1,11 @@
 import gradio as gr
 import os
-
+from datetime import datetime
+
+import numpy as np
+import torch
+
+from modules.utils import resolve_model_path
 
 def create_ui():
     with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -78,9 +83,9 @@ def create_ui():
         inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
         outputs = [output_gallery, used_seed]
 
-        # generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
-        # prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
-        # neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+        prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
 
         # System
         # load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
@@ -91,5 +96,120 @@ def create_ui():
         # demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
 
 
-def create_event_handlers():
-    pass
+def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
+    if engine.pipe is None:
+        load_pipeline("SDPA", False, False, "bfloat16")
+
+    if seed == -1: seed = np.random.randint(0, 2**32 - 1)
+
+    # One CUDA generator is reused across all batches; the sequential loop
+    # below re-seeds it with base seed + batch index for each pass.
+    generator = torch.Generator("cuda").manual_seed(int(seed))
+
+    all_images = []
+
+    # SEQUENTIAL LOOP (Easy on VRAM)
+    for i in range(int(batch_count_seq)):
+        # Update seed for subsequent batches so they aren't identical
+        current_seed = seed + i
+        generator.manual_seed(int(current_seed))
+
+        # PARALLEL GENERATION (Fast, Heavy VRAM)
+        # diffusers handles 'num_images_per_prompt' for parallel execution
+        batch_results = engine.pipe(
+            prompt=prompt,
+            negative_prompt=neg_prompt,
+            height=int(height),
+            width=int(width),
+            num_inference_steps=int(steps),
+            guidance_scale=0.0,
+            num_images_per_prompt=int(batch_size_parallel),
+            generator=generator,
+        ).images
+
+        all_images.extend(batch_results)
+
+    # Save to disk
+    save_generated_images(all_images, prompt, seed)
+
+    return all_images, seed
+
+
+# --- File Saving Logic ---
+def save_generated_images(images, prompt, seed):
+    # Create folder structure: outputs/YYYY-MM-DD/
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder_path = os.path.join("outputs", date_str)
+    os.makedirs(folder_path, exist_ok=True)
+
+    saved_paths = []
+
+    for idx, img in enumerate(images):
+        timestamp = datetime.now().strftime("%H-%M-%S")
+        # Filename: Time_Seed_Index.png
+        filename = f"{timestamp}_{seed}_{idx}.png"
+        full_path = os.path.join(folder_path, filename)
+
+        # Save image
+        img.save(full_path)
+        saved_paths.append(full_path)
+
+        # Optional: Save prompt metadata in a text file next to it?
+        # For now, we just save the image.
+
+    return saved_paths
+
+
+# --- History Logic ---
+def get_history_images():
+    """Scans the outputs folder and returns all images, newest first"""
+    if not os.path.exists("outputs"):
+        return []
+
+    image_files = []
+    for root, dirs, files in os.walk("outputs"):
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg")):
+                full_path = os.path.join(root, file)
+                # Get modification time for sorting
+                mtime = os.path.getmtime(full_path)
+                image_files.append((full_path, mtime))
+
+    # Sort by time, newest first
+    image_files.sort(key=lambda x: x[1], reverse=True)
+    return [x[0] for x in image_files]
+
+
+def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
+    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
+    target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
+
+    if engine.pipe is not None and engine.config["backend"] == attn_backend:
+        return gr.Info("Pipeline ready.")
+
+    try:
+        gr.Info(f"Loading Model ({attn_backend})...")
+        pipe = ZImagePipeline.from_pretrained(
+            engine.config["model_id"],
+            torch_dtype=target_dtype,
+            low_cpu_mem_usage=False,
+        )
+        if not offload_cpu: pipe.to("cuda")
+
+        if attn_backend == "Flash Attention 2":
+            pipe.transformer.set_attention_backend("flash")
+        elif attn_backend == "Flash Attention 3":
+            pipe.transformer.set_attention_backend("_flash_3")
+
+        if compile_model:
+            gr.Warning("Compiling... First run will be slow.")
+            pipe.transformer.compile()
+
+        if offload_cpu: pipe.enable_model_cpu_offload()
+
+        engine.pipe = pipe
+        engine.config["backend"] = attn_backend
+        return gr.Success("System Ready.")
+    except Exception as e:
+        return gr.Error(f"Init Failed: {str(e)}")
+
diff --git a/modules/utils.py b/modules/utils.py
index 5315d0f8..13a814ae 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -86,7 +86,7 @@ def check_model_loaded():
     return True, None
 
 
-def resolve_model_path(model_name_or_path):
+def resolve_model_path(model_name_or_path, image_model=False):
     """
     Resolves a model path, checking for a direct path before the default models directory.
@@ -95,6 +95,8 @@
     path_candidate = Path(model_name_or_path)
     if path_candidate.exists():
         return path_candidate
+    elif image_model:
+        return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
     else:
         return Path(f'{shared.args.model_dir}/{model_name_or_path}')
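
Usage sketch for the modules/utils.py change (illustrative, not part of the diff): with the new image_model flag, a name that is not an existing path falls back to shared.args.image_model_dir instead of shared.args.model_dir. The model names below are placeholders.

    from modules.utils import resolve_model_path

    # Text model: unchanged fallback to <model_dir>/<name>
    llm_path = resolve_model_path("some-llm")

    # Image model: new fallback to <image_model_dir>/<name>
    image_path = resolve_model_path("some-image-model", image_model=True)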