Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-12-06 07:12:10 +01:00)
Commit 2f11b3040d ("Add functions"), parent aa63c612de

@@ -1,6 +1,6 @@
import gradio as gr
import os

from modules.utils import resolve_model_path

def create_ui():
    with gr.Tab("Image AI", elem_id="image-ai-tab"):

@@ -78,9 +78,9 @@ def create_ui():
        inputs = [prompt, neg_prompt, width_slider, height_slider, steps_slider, seed_input, batch_size_parallel, batch_count_seq]
        outputs = [output_gallery, used_seed]

-        # generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
-        # prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
-        # neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        generate_btn.click(fn=generate, inputs=inputs, outputs=outputs)
+        prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
+        neg_prompt.submit(fn=generate, inputs=inputs, outputs=outputs)

        # System
        # load_btn.click(fn=load_pipeline, inputs=[backend_drop, compile_check, offload_check, gr.State("bfloat16")], outputs=None)
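
The three uncommented handlers above route the button click and Enter-in-textbox events to the same generate callback through shared inputs/outputs lists. A minimal, self-contained sketch of that Gradio wiring pattern, using placeholder components and a stub callback rather than the ones from this tab:

import gradio as gr

def echo(prompt, neg_prompt):
    # Stub standing in for generate(); returns a string instead of images
    return f"prompt={prompt!r}, negative={neg_prompt!r}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    neg_prompt = gr.Textbox(label="Negative prompt")
    generate_btn = gr.Button("Generate")
    result = gr.Textbox(label="Result")

    # One callback, three triggers: button click plus Enter in either textbox
    inputs = [prompt, neg_prompt]
    generate_btn.click(fn=echo, inputs=inputs, outputs=result)
    prompt.submit(fn=echo, inputs=inputs, outputs=result)
    neg_prompt.submit(fn=echo, inputs=inputs, outputs=result)

demo.launch()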

@@ -91,5 +91,120 @@ def create_ui():
        # demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)


def create_event_handlers():
    pass


def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
    if engine.pipe is None:
        load_pipeline("SDPA", False, False, "bfloat16")

    if seed == -1: seed = np.random.randint(0, 2**32 - 1)

    # A base generator is created here; inside the sequential loop the seed is
    # offset per batch so consecutive batches aren't identical.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    all_images = []

    # SEQUENTIAL LOOP (easy on VRAM)
    for i in range(batch_count_seq):
        # Update the seed for subsequent batches so they aren't identical
        current_seed = seed + i
        generator.manual_seed(int(current_seed))

        # PARALLEL GENERATION (fast, heavy on VRAM)
        # diffusers handles 'num_images_per_prompt' for parallel execution
        batch_results = engine.pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=int(height),
            width=int(width),
            num_inference_steps=int(steps),
            guidance_scale=0.0,
            num_images_per_prompt=int(batch_size_parallel),
            generator=generator,
        ).images

        all_images.extend(batch_results)

    # Save to disk
    save_generated_images(all_images, prompt, seed)

    return all_images, seed
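
The seed handling in generate (pick a random seed when -1, then offset it by the batch index) can be exercised on its own. A small sketch of that reproducibility pattern using only numpy and torch on CPU; make_generator is an illustrative helper, not part of the commit:

import numpy as np
import torch

def make_generator(seed, batch_index, device="cpu"):
    # -1 means "pick a fresh random seed", otherwise reuse the given one;
    # the batch index offsets the seed so sequential batches differ
    if seed == -1:
        seed = int(np.random.randint(0, 2**32 - 1))
    generator = torch.Generator(device).manual_seed(int(seed) + batch_index)
    return seed, generator

base_seed, g1 = make_generator(-1, 1)
_, g1_again = make_generator(base_seed, 1)
# The same base seed and batch index reproduce the same noise
assert torch.equal(torch.randn(4, generator=g1), torch.randn(4, generator=g1_again))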


# --- File Saving Logic ---
def save_generated_images(images, prompt, seed):
    # Create folder structure: outputs/YYYY-MM-DD/
    date_str = datetime.now().strftime("%Y-%m-%d")
    folder_path = os.path.join("outputs", date_str)
    os.makedirs(folder_path, exist_ok=True)

    saved_paths = []

    for idx, img in enumerate(images):
        timestamp = datetime.now().strftime("%H-%M-%S")
        # Filename: Time_Seed_Index.png
        filename = f"{timestamp}_{seed}_{idx}.png"
        full_path = os.path.join(folder_path, filename)

        # Save image
        img.save(full_path)
        saved_paths.append(full_path)

    # Optional: save prompt metadata in a text file next to it?
    # For now, we just save the image.

    return saved_paths
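
One way to act on the "save prompt metadata" comment without adding a sidecar text file would be to embed the parameters in the PNG itself via Pillow's PngInfo text chunks. A hedged sketch, not part of this commit; save_with_metadata is an illustrative helper:

from PIL.PngImagePlugin import PngInfo

def save_with_metadata(img, path, prompt, seed):
    # Store generation parameters as PNG text chunks alongside the pixels
    meta = PngInfo()
    meta.add_text("prompt", prompt)
    meta.add_text("seed", str(seed))
    img.save(path, pnginfo=meta)

# Reading it back later: Image.open(path).text -> {"prompt": ..., "seed": ...}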


# --- History Logic ---
def get_history_images():
    """Scans the outputs folder and returns all images, newest first"""
    if not os.path.exists("outputs"):
        return []

    image_files = []
    for root, dirs, files in os.walk("outputs"):
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg")):
                full_path = os.path.join(root, file)
                # Get modification time for sorting
                mtime = os.path.getmtime(full_path)
                image_files.append((full_path, mtime))

    # Sort by time, newest first
    image_files.sort(key=lambda x: x[1], reverse=True)
    return [x[0] for x in image_files]
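
The commented demo.load call earlier in the diff suggests how this feeds the UI: get_history_images returns a list of file paths, which gr.Gallery accepts directly. A minimal wiring sketch, assuming the function above is in scope; the refresh button is a hypothetical addition, not part of the commit:

import gradio as gr
# get_history_images is assumed to be imported from the module above

with gr.Blocks() as demo:
    history_gallery = gr.Gallery(label="History", columns=4)
    refresh_btn = gr.Button("Refresh history")

    # Populate on page load and again on demand
    demo.load(fn=get_history_images, inputs=None, outputs=history_gallery)
    refresh_btn.click(fn=get_history_images, inputs=None, outputs=history_gallery)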


def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    target_dtype = dtype_map.get(dtype_str, torch.bfloat16)

    if engine.pipe is not None and engine.config["backend"] == attn_backend:
        return gr.Info("Pipeline ready.")

    try:
        gr.Info(f"Loading Model ({attn_backend})...")
        pipe = ZImagePipeline.from_pretrained(
            engine.config["model_id"],
            torch_dtype=target_dtype,
            low_cpu_mem_usage=False,
        )
        if not offload_cpu: pipe.to("cuda")

        if attn_backend == "Flash Attention 2":
            pipe.transformer.set_attention_backend("flash")
        elif attn_backend == "Flash Attention 3":
            pipe.transformer.set_attention_backend("_flash_3")

        if compile_model:
            gr.Warning("Compiling... First run will be slow.")
            pipe.transformer.compile()

        if offload_cpu: pipe.enable_model_cpu_offload()

        engine.pipe = pipe
        engine.config["backend"] = attn_backend
        return gr.Success("System Ready.")
    except Exception as e:
        return gr.Error(f"Init Failed: {str(e)}")
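
The dtype map, optional CPU offload, and optional compile in load_pipeline follow a common diffusers loading pattern. A generic sketch of the same idea using DiffusionPipeline and a placeholder model id; ZImagePipeline, engine, and the attention-backend switching above are specific to this commit and not reproduced here:

import torch
from diffusers import DiffusionPipeline

def load(model_id, dtype_str="bfloat16", offload_cpu=False, compile_model=False):
    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype_map.get(dtype_str, torch.bfloat16),
    )
    if offload_cpu:
        # Submodules are moved to the GPU on demand, trading speed for VRAM
        pipe.enable_model_cpu_offload()
    else:
        pipe.to("cuda")
    if compile_model and hasattr(pipe, "transformer"):
        # Mirrors the commit's pipe.transformer.compile(); the first call pays
        # the compilation cost and later calls are faster
        pipe.transformer.compile()
    return pipe

# pipe = load("<model id or local path>")  # placeholder, not a value from the commit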

modules/utils.py
@@ -86,7 +86,7 @@ def check_model_loaded():
    return True, None


-def resolve_model_path(model_name_or_path):
+def resolve_model_path(model_name_or_path, image_model=False):
    """
    Resolves a model path, checking for a direct path
    before the default models directory.

@@ -95,6 +95,8 @@ def resolve_model_path(model_name_or_path):
    path_candidate = Path(model_name_or_path)
    if path_candidate.exists():
        return path_candidate
+    elif image_model:
+        return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
    else:
        return Path(f'{shared.args.model_dir}/{model_name_or_path}')
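
With the new flag, callers that deal with image models can route bare model names to shared.args.image_model_dir while existing callers keep the old behaviour. A usage sketch; the argument values are placeholders, not taken from the commit:

from modules.utils import resolve_model_path

# If the argument is an existing path, it is returned unchanged
local = resolve_model_path("/path/to/a/local/checkpoint")

# A bare name falls back to shared.args.model_dir,
# or to shared.args.image_model_dir when image_model=True
text_model_path = resolve_model_path("some-text-model")
image_model_path = resolve_model_path("some-image-model", image_model=True)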