mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Image generation now functional
This commit is contained in:
parent
2f11b3040d
commit
a873692234
|
|
@ -51,8 +51,12 @@ group.add_argument('--verbose', action='store_true', help='Print the prompts to
|
|||
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
|
||||
|
||||
# Image generation
|
||||
group = parser.add_argument_group('Image model')
|
||||
group.add_argument('--image-model', type=str, help='Name of the image model to load by default.')
|
||||
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
|
||||
group.add_argument('--image-dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], help='Data type for image model.')
|
||||
group.add_argument('--image-attn-backend', type=str, default='sdpa', choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
|
||||
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
||||
|
||||
# Model loader
|
||||
group = parser.add_argument_group('Model loader')
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
from modules.utils import resolve_model_path
|
||||
from datetime import datetime
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.image_models import load_image_model, unload_image_model
|
||||
|
||||
|
||||
def create_ui():
|
||||
with gr.Tab("Image AI", elem_id="image-ai-tab"):
|
||||
|
|
@ -92,26 +99,33 @@ def create_ui():
|
|||
|
||||
|
||||
def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel, batch_count_seq):
|
||||
if engine.pipe is None:
|
||||
load_pipeline("SDPA", False, False, "bfloat16")
|
||||
|
||||
if seed == -1: seed = np.random.randint(0, 2**32 - 1)
|
||||
|
||||
# We use a base generator. For sequential batches, we might increment seed if desired,
|
||||
# but here we keep the base seed logic consistent.
|
||||
import numpy as np
|
||||
import torch
|
||||
from modules import shared
|
||||
from modules.image_models import load_image_model
|
||||
|
||||
# Auto-load model if not loaded
|
||||
if shared.image_model is None:
|
||||
if shared.image_model_name == 'None':
|
||||
return [], "No image model selected. Please load a model first."
|
||||
load_image_model(shared.image_model_name)
|
||||
|
||||
if shared.image_model is None:
|
||||
return [], "Failed to load image model."
|
||||
|
||||
if seed == -1:
|
||||
seed = np.random.randint(0, 2**32 - 1)
|
||||
|
||||
generator = torch.Generator("cuda").manual_seed(int(seed))
|
||||
|
||||
all_images = []
|
||||
|
||||
# SEQUENTIAL LOOP (Easy on VRAM)
|
||||
for i in range(batch_count_seq):
|
||||
# Update seed for subsequent batches so they aren't identical
|
||||
|
||||
# Sequential loop (easier on VRAM)
|
||||
for i in range(int(batch_count_seq)):
|
||||
current_seed = seed + i
|
||||
generator.manual_seed(int(current_seed))
|
||||
|
||||
# PARALLEL GENERATION (Fast, Heavy VRAM)
|
||||
# diffusers handles 'num_images_per_prompt' for parallel execution
|
||||
batch_results = engine.pipe(
|
||||
|
||||
# Parallel generation
|
||||
batch_results = shared.image_model(
|
||||
prompt=prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
height=int(height),
|
||||
|
|
@ -121,13 +135,13 @@ def generate(prompt, neg_prompt, width, height, steps, seed, batch_size_parallel
|
|||
num_images_per_prompt=int(batch_size_parallel),
|
||||
generator=generator,
|
||||
).images
|
||||
|
||||
|
||||
all_images.extend(batch_results)
|
||||
|
||||
|
||||
# Save to disk
|
||||
save_generated_images(all_images, prompt, seed)
|
||||
|
||||
return all_images, seed
|
||||
|
||||
return all_images, f"Seed: {seed}"
|
||||
|
||||
|
||||
# --- File Saving Logic ---
|
||||
|
|
@ -173,38 +187,3 @@ def get_history_images():
|
|||
# Sort by time, newest first
|
||||
image_files.sort(key=lambda x: x[1], reverse=True)
|
||||
return [x[0] for x in image_files]
|
||||
|
||||
|
||||
def load_pipeline(attn_backend, compile_model, offload_cpu, dtype_str):
|
||||
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
|
||||
target_dtype = dtype_map.get(dtype_str, torch.bfloat16)
|
||||
|
||||
if engine.pipe is not None and engine.config["backend"] == attn_backend:
|
||||
return gr.Info("Pipeline ready.")
|
||||
|
||||
try:
|
||||
gr.Info(f"Loading Model ({attn_backend})...")
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
engine.config["model_id"],
|
||||
torch_dtype=target_dtype,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
if not offload_cpu: pipe.to("cuda")
|
||||
|
||||
if attn_backend == "Flash Attention 2":
|
||||
pipe.transformer.set_attention_backend("flash")
|
||||
elif attn_backend == "Flash Attention 3":
|
||||
pipe.transformer.set_attention_backend("_flash_3")
|
||||
|
||||
if compile_model:
|
||||
gr.Warning("Compiling... First run will be slow.")
|
||||
pipe.transformer.compile()
|
||||
|
||||
if offload_cpu: pipe.enable_model_cpu_offload()
|
||||
|
||||
engine.pipe = pipe
|
||||
engine.config["backend"] = attn_backend
|
||||
return gr.Success("System Ready.")
|
||||
except Exception as e:
|
||||
return gr.Error(f"Init Failed: {str(e)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ def create_ui():
|
|||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['image_model_menu'] = gr.Dropdown(choices=utils.get_available_image_models(), value=lambda: shared.image_model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu)
|
||||
ui.create_refresh_button(shared.gradio['image_model_menu'], lambda: None, lambda: {'choices': utils.get_available_image_models()}, 'refresh-button', interactive=not mu)
|
||||
shared.gradio['image_load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu)
|
||||
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu)
|
||||
shared.gradio['image_save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu)
|
||||
|
|
@ -169,7 +169,7 @@ def create_ui():
|
|||
shared.gradio['image_download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu)
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||
shared.gradio['image_model_status'] = gr.Markdown('No model is loaded' if shared.image_model_name == 'None' else 'Ready')
|
||||
|
||||
|
||||
def create_event_handlers():
|
||||
|
|
@ -220,6 +220,28 @@ def create_event_handlers():
|
|||
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
||||
|
||||
# Image model event handlers
|
||||
shared.gradio['image_load_model'].click(
|
||||
load_image_model_wrapper,
|
||||
gradio('image_model_menu'),
|
||||
gradio('image_model_status'),
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
shared.gradio['image_unload_model'].click(
|
||||
handle_unload_image_model_click,
|
||||
None,
|
||||
gradio('image_model_status'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['image_download_model_button'].click(
|
||||
download_image_model_wrapper,
|
||||
gradio('image_custom_model_menu'),
|
||||
gradio('image_model_status'),
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
try:
|
||||
|
|
@ -471,3 +493,66 @@ def format_file_size(size_bytes):
|
|||
return f"{s:.2f} {size_names[i]}"
|
||||
else:
|
||||
return f"{s:.1f} {size_names[i]}"
|
||||
|
||||
|
||||
def load_image_model_wrapper(selected_model):
|
||||
"""Wrapper for loading image models with status updates."""
|
||||
from modules.image_models import load_image_model, unload_image_model
|
||||
|
||||
if selected_model == 'None' or not selected_model:
|
||||
yield "No model selected"
|
||||
return
|
||||
|
||||
try:
|
||||
yield f"Loading `{selected_model}`..."
|
||||
unload_image_model()
|
||||
result = load_image_model(selected_model)
|
||||
|
||||
if result is not None:
|
||||
yield f"Successfully loaded `{selected_model}`."
|
||||
else:
|
||||
yield f"Failed to load `{selected_model}`."
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
yield exc.replace('\n', '\n\n')
|
||||
|
||||
|
||||
def handle_unload_image_model_click():
|
||||
"""Handler for the image model unload button."""
|
||||
from modules.image_models import unload_image_model
|
||||
unload_image_model()
|
||||
return "Image model unloaded"
|
||||
|
||||
|
||||
def download_image_model_wrapper(custom_model):
|
||||
"""Download an image model from Hugging Face."""
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if not custom_model:
|
||||
yield "No model specified"
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse model name and branch
|
||||
if ':' in custom_model:
|
||||
model_name, branch = custom_model.rsplit(':', 1)
|
||||
else:
|
||||
model_name, branch = custom_model, 'main'
|
||||
|
||||
# Output folder
|
||||
output_folder = Path(shared.args.image_model_dir) / model_name.split('/')[-1]
|
||||
|
||||
yield f"Downloading `{model_name}` (branch: {branch})..."
|
||||
|
||||
snapshot_download(
|
||||
repo_id=model_name,
|
||||
revision=branch,
|
||||
local_dir=output_folder,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
yield f"Model successfully saved to `{output_folder}/`."
|
||||
|
||||
except Exception:
|
||||
exc = traceback.format_exc()
|
||||
yield exc.replace('\n', '\n\n')
|
||||
|
|
|
|||
Loading…
Reference in a new issue