Add functions

oobabooga 2025-11-27 13:53:46 -08:00
parent aa63c612de
commit 2f11b3040d
2 changed files with 124 additions and 7 deletions


@@ -1,6 +1,6 @@
 import gradio as gr
 import os
+from modules.utils import resolve_model_path
 
 def create_ui():
     with gr.Tab("Image AI", elem_id="image-ai-tab"):
@@ -78,9 +78,9 @@ def create_ui():
         inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
         outputs = [output_gallery, used_seed]
 
-        # generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
-        # prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
-        # neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+        prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
 
         # System
         # load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
@@ -91,5 +91,120 @@ def create_ui():
         # demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
 
-def create_event_handlers():
-    pass
+def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
+    if engine.pipe is None:
+        load_pipeline("SDPA", False, False, "bfloat16")
+
+    if seed == -1: seed = np.random.randint(0, 2**32 - 1)
+
+    # We use a base generator. For sequential batches, we might increment seed if desired,
+    # but here we keep the base seed logic consistent.
+    generator = torch.Generator("cuda").manual_seed(int(seed))
+
+    all_images = []
+
+    # SEQUENTIAL LOOP (Easy on VRAM)
+    for i in range(batch_count_seq):
+        # Update seed for subsequent batches so they aren't identical
+        current_seed = seed + i
+        generator.manual_seed(int(current_seed))
+
+        # PARALLEL GENERATION (Fast, Heavy VRAM)
+        # diffusers handles 'num_images_per_prompt' for parallel execution
+        batch_results = engine.pipe(
+            prompt=prompt,
+            negative_prompt=neg_prompt,
+            height=int(height),
+            width=int(width),
+            num_inference_steps=int(steps),
+            guidance_scale=0.0,
+            num_images_per_prompt=int(batch_size_parallel),
+            generator=generator,
+        ).images
+
+        all_images.extend(batch_results)
+
+    # Save to disk
+    save_generated_images(all_images, prompt, seed)
+
+    return all_images, seed
+
+
+# --- File Saving Logic ---
+def save_generated_images(images, prompt, seed):
+    # Create folder structure: outputs/YYYY-MM-DD/
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder_path = os.path.join("outputs", date_str)
+    os.makedirs(folder_path, exist_ok=True)
+
+    saved_paths = []
+    for idx, img in enumerate(images):
+        timestamp = datetime.now().strftime("%H-%M-%S")
+        # Filename: Time_Seed_Index.png
+        filename = f"{timestamp}_{seed}_{idx}.png"
+        full_path = os.path.join(folder_path, filename)
+
+        # Save image
+        img.save(full_path)
+        saved_paths.append(full_path)
+
+    # Optional: Save prompt metadata in a text file next to it?
+    # For now, we just save the image.
+    return saved_paths
+
+
+# --- History Logic ---
+def get_history_images():
+    """Scans the outputs folder and returns all images, newest first"""
+    if not os.path.exists("outputs"):
+        return []
+
+    image_files = []
+    for root, dirs, files in os.walk("outputs"):
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg")):
+                full_path = os.path.join(root, file)
+                # Get creation time for sorting
+                mtime = os.path.getmtime(full_path)
+                image_files.append((full_path, mtime))
+
+    # Sort by time, newest first
+    image_files.sort(key=lambda x: x[1], reverse=True)
+    return [x[0] for x in image_files]
+
+
+def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
+    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
+    target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
+
+    if engine.pipe is not None and engine.config["backend"] == attn_backend:
+        return gr.Info("Pipeline ready.")
+
+    try:
+        gr.Info(f"Loading Model ({attn_backend})...")
+        pipe = ZImagePipeline.from_pretrained(
+            engine.config["model_id"],
+            torch_dtype=target_dtype,
+            low_cpu_mem_usage=False,
+        )
+
+        if not offload_cpu: pipe.to("cuda")
+
+        if attn_backend == "Flash Attention 2":
+            pipe.transformer.set_attention_backend("flash")
+        elif attn_backend == "Flash Attention 3":
+            pipe.transformer.set_attention_backend("_flash_3")
+
+        if compile_model:
+            gr.Warning("Compiling... First run will be slow.")
+            pipe.transformer.compile()
+
+        if offload_cpu: pipe.enable_model_cpu_offload()
+
+        engine.pipe = pipe
+        engine.config["backend"] = attn_backend
+        return gr.Success("System Ready.")
+    except Exception as e:
+        return gr.Error(f"Init Failed: {str(e)}")
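Note: the hunks above don't show the imports or the module-level `engine` object that the new functions rely on (`np`, `torch`, `datetime`, `ZImagePipeline`); they may already be defined elsewhere in the file. As a rough sketch only, the module header these names assume could look like the following; the diffusers import path and the `Engine` holder are assumptions, not part of this commit.

import os
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import ZImagePipeline  # assumed import path for the Z-Image pipeline

from modules.utils import resolve_model_path


class Engine:
    # Hypothetical holder for the loaded pipeline and its settings
    def __init__(self):
        self.pipe = None
        self.config = {"model_id": "placeholder/z-image-model", "backend": None}  # placeholder id


engine = Engine()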


@@ -86,7 +86,7 @@ def check_model_loaded():
     return True, None
 
-def resolve_model_path(model_name_or_path):
+def resolve_model_path(model_name_or_path, image_model=False):
     """
     Resolves a model path, checking for a direct path
     before the default models directory.
@@ -95,6 +95,8 @@ def resolve_model_path(model_name_or_path):
     path_candidate = Path(model_name_or_path)
     if path_candidate.exists():
         return path_candidate
+    elif image_model:
+        return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
     else:
         return Path(f'{shared.args.model_dir}/{model_name_or_path}')
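With the new `image_model` flag, a model name that is not an existing path now resolves against `shared.args.image_model_dir` instead of `shared.args.model_dir`. A small usage sketch (the model names below are illustrative, not taken from this commit):

from modules.utils import resolve_model_path

# Existing behavior: text models fall back to the text-model directory
text_path = resolve_model_path("some-text-model")                       # -> <model_dir>/some-text-model

# New behavior: image models fall back to the image-model directory
image_path = resolve_model_path("some-image-model", image_model=True)   # -> <image_model_dir>/some-image-model

# A path that already exists on disk is returned unchanged in either case
local_path = resolve_model_path("/tmp/checkpoints/my-model")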