Image generation now functional

oobabooga 2025-11-27 14:24:35 -08:00
parent 2f11b3040d
commit a873692234
3 changed files with 127 additions and 59 deletions


@@ -51,8 +51,12 @@ group.add_argument('--verbose', action='store_true', help='Print the prompts to
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
# Image generation
group = parser.add_argument_group('Image model')
group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
group.add_argument('--image-dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], help='Data type for image model.')
group.add_argument('--image-attn-backend', type=str, default='sdpa', choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
# Model loader
group = parser.add_argument_group('Model loader')
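For reference, the new --image-* flags above are ordinary argparse options, so they are passed on the command line at startup. Assuming the usual server.py entry point (not part of this diff) and a placeholder model folder name, an invocation might look like:

python server.py --image-model my-image-model --image-dtype float16 --image-attn-backend flash_attention_2 --image-cpu-offload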


@@ -1,6 +1,13 @@
import gradio as gr
import os
from modules.utils import resolve_model_path
from datetime import datetime
import gradio as gr
import numpy as np
import torch
from modules import shared
from modules.image_models import load_image_model, unload_image_model
def create_ui():
with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -92,26 +99,33 @@ def create_ui():
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
if engine.pipe is None:
load_pipeline("SDPA", False, False, "bfloat16")
if seed == -1: seed = np.random.randint(0, 2**32 - 1)
# We use a base generator. For sequential batches, we might increment seed if desired,
# but here we keep the base seed logic consistent.
import numpy as np
import torch
from modules import shared
from modules.image_models import load_image_model
# Auto-load model if not loaded
if shared.image_model is None:
if shared.image_model_name == 'None':
return [], "No image model selected. Please load a model first."
load_image_model(shared.image_model_name)
if shared.image_model is None:
return [], "Failed to load image model."
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
generator = torch.Generator("cuda").manual_seed(int(seed))
all_images = []
# SEQUENTIAL LOOP (Easy on VRAM)
for i in range(batch_count_seq):
# Update seed for subsequent batches so they aren't identical
# Sequential loop (easier on VRAM)
for i in range(int(batch_count_seq)):
current_seed = seed + i
generator.manual_seed(int(current_seed))
# PARALLEL GENERATION (Fast, Heavy VRAM)
# diffusers handles 'num_images_per_prompt' for parallel execution
batch_results = engine.pipe(
# Parallel generation
batch_results = shared.image_model(
prompt=prompt,
negative_prompt=neg_prompt,
height=int(height),
@@ -121,13 +135,13 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
num_images_per_prompt=int(batch_size_parallel),
generator=generator,
).images
all_images.extend(batch_results)
# Save to disk
save_generated_images(all_images, prompt, seed)
return all_images, seed
return all_images, f"Seed: {seed}"
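As a worked example of the batching scheme above: with batch_size_parallel=4 and batch_count_seq=2, the pipeline is called twice with num_images_per_prompt=4, the first call seeded with seed and the second with seed + 1, for 8 images total. Per the comments, VRAM usage grows with batch_size_parallel (parallel generation within one call), while batch_count_seq only adds sequential calls and wall-clock time.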
# --- File Saving Logic ---
@@ -173,38 +187,3 @@ def get_history_images():
# Sort by time, newest first
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]
def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
if engine.pipe is not None and engine.config["backend"] == attn_backend:
return gr.Info("Pipeline ready.")
try:
gr.Info(f"Loading Model ({attn_backend})...")
pipe = ZImagePipeline.from_pretrained(
engine.config["model_id"],
torch_dtype=target_dtype,
low_cpu_mem_usage=False,
)
if not offload_cpu: pipe.to("cuda")
if attn_backend == "Flash Attention 2":
pipe.transformer.set_attention_backend("flash")
elif attn_backend == "Flash Attention 3":
pipe.transformer.set_attention_backend("_flash_3")
if compile_model:
gr.Warning("Compiling... First run will be slow.")
pipe.transformer.compile()
if offload_cpu: pipe.enable_model_cpu_offload()
engine.pipe = pipe
engine.config["backend"] = attn_backend
return gr.Success("System Ready.")
except Exception as e:
return gr.Error(f"Init Failed: {str(e)}")
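The removed load_pipeline() above is superseded by modules.image_models (imported by the new code but not shown in this diff). A minimal sketch of what its load_image_model() might look like, assuming diffusers' DiffusionPipeline, a transformer-based pipeline like the one removed here, and the new --image-* flags from the first hunk; everything beyond those grounded pieces (names, structure) is an assumption:

from pathlib import Path
import torch
from diffusers import DiffusionPipeline
from modules import shared

def load_image_model(model_name):
    # Map the --image-dtype choice onto a torch dtype (default bfloat16).
    dtype = torch.float16 if shared.args.image_dtype == 'float16' else torch.bfloat16
    model_path = Path(shared.args.image_model_dir) / model_name
    pipe = DiffusionPipeline.from_pretrained(str(model_path), torch_dtype=dtype)

    # Map the --image-attn-backend choices onto the backend names the removed code used.
    if shared.args.image_attn_backend == 'flash_attention_2':
        pipe.transformer.set_attention_backend('flash')
    elif shared.args.image_attn_backend == 'flash_attention_3':
        pipe.transformer.set_attention_backend('_flash_3')

    # --image-cpu-offload trades speed for VRAM, as enable_model_cpu_offload() did above.
    if shared.args.image_cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to('cuda')

    shared.image_model = pipe
    shared.image_model_name = model_name
    return pipe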


@@ -140,7 +140,7 @@ def create_ui():
with gr.Column():
with gr.Row():
shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
@@ -169,7 +169,7 @@ def create_ui():
shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
with gr.Row():
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
def create_event_handlers():
@@ -220,6 +220,28 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
# Image model event handlers
shared.gradio['image_load_model'].click(
load_image_model_wrapper,
gradio('image_model_menu'),
gradio('image_model_status'),
show_progress=True
)
shared.gradio['image_unload_model'].click(
handle_unload_image_model_click,
None,
gradio('image_model_status'),
show_progress=False
)
shared.gradio['image_download_model_button'].click(
download_image_model_wrapper,
gradio('image_custom_model_menu'),
gradio('image_model_status'),
show_progress=True
)
def load_model_wrapper(selected_model, loader, autoload=False):
try:
@@ -471,3 +493,66 @@ def format_file_size(size_bytes):
return f"{s:.2f} {size_names[i]}"
else:
return f"{s:.1f} {size_names[i]}"
def load_image_model_wrapper(selected_model):
"""Wrapper for loading image models with status updates."""
from modules.image_models import load_image_model, unload_image_model
if selected_model == 'None' or not selected_model:
yield "No model selected"
return
try:
yield f"Loading `{selected_model}`..."
unload_image_model()
result = load_image_model(selected_model)
if result is not None:
yield f"Successfully loaded `{selected_model}`."
else:
yield f"Failed to load `{selected_model}`."
except Exception:
exc = traceback.format_exc()
yield exc.replace('\n', '\n\n')
def handle_unload_image_model_click():
"""Handler for the image model unload button."""
from modules.image_models import unload_image_model
unload_image_model()
return "Image model unloaded"
def download_image_model_wrapper(custom_model):
"""Download an image model from Hugging Face."""
from huggingface_hub import snapshot_download
if not custom_model:
yield "No model specified"
return
try:
# Parse model name and branch
if ':' in custom_model:
model_name, branch = custom_model.rsplit(':', 1)
else:
model_name, branch = custom_model, 'main'
# Output folder
output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
yield f"Downloading `{model_name}` (branch: {branch})..."
snapshot_download(
repo_id=model_name,
revision=branch,
local_dir=output_folder,
local_dir_use_symlinks=False,
)
yield f"Model successfully saved to `{output_folder}/`."
except Exception:
exc = traceback.format_exc()
yield exc.replace('\n', '\n\n')
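For example, a hypothetical input of org/my-image-model:fp16 would be split by the rsplit above into repo_id org/my-image-model and branch fp16, and the snapshot would land in user_data/image_models/my-image-model/ (or under whatever --image-model-dir points to); an input without a colon downloads the main branch.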