mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
426 lines
17 KiB
Python
426 lines
17 KiB
Python
import os
|
|
import traceback
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import gradio as gr
|
|
import numpy as np
|
|
import torch
|
|
|
|
from modules import shared, ui, utils
|
|
from modules.image_models import load_image_model, unload_image_model
|
|
from modules.utils import gradio
|
|
|
|
# Grid size in pixels: round_to_step() snaps to multiples of this by default,
# and it is also the tolerance used when matching dimensions to a preset.
STEP = 32

# Aspect-ratio presets keyed by their UI label. Each value is a
# (width_ratio, height_ratio) pair; "Custom" maps to None, meaning the
# width and height are not tied to each other.
# Insertion order matters: presets are matched in this order when detecting
# which ratio a width/height pair corresponds to.
ASPECT_RATIOS = {
    "1:1 Square": (1, 1),
    "16:9 Cinema": (16, 9),
    "9:16 Mobile": (9, 16),
    "4:3 Photo": (4, 3),
    "Custom": None,
}
|
|
|
|
|
|
def round_to_step(value, step=STEP):
    """Snap ``value`` to the nearest multiple of ``step``.

    Values lying exactly halfway between two multiples are rounded up.
    The built-in ``round`` rounds halves to the nearest *even* integer, which
    made half-way values snap up or down depending on parity (with a step of
    32: 240 -> 256 but 336 -> 320). Rounding halves up keeps the snapping
    uniform across the whole range.

    Args:
        value: Number to snap (a pixel dimension in this module).
        step: Grid size; expected to be positive.

    Returns:
        The multiple of ``step`` closest to ``value``.
    """
    # Adding half a step and flooring is "round half up" for a positive step.
    return int((value + step / 2) // step) * step
|
|
|
|
|
|
def clamp(value, min_val, max_val):
    """Restrict ``value`` to the range ``[min_val, max_val]``.

    The upper bound is applied first and the lower bound second, so if the
    bounds are inverted (``min_val > max_val``) the result is ``min_val``.
    """
    upper_bounded = value if value < max_val else max_val
    return upper_bounded if upper_bounded > min_val else min_val
|
|
|
|
|
|
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
    """Recompute width and height so they match the selected aspect-ratio preset.

    For a square preset both sides become the smaller of the current sides.
    For a portrait preset the current width is kept and the height derived
    from it; for a landscape preset the current height is kept and the width
    derived from it. Derived values are snapped with ``round_to_step``.

    If the derived (longer) side would exceed the 2048 limit, it is capped at
    2048 and the *other* side is re-derived from it. Clamping the two sides
    independently would silently break the ratio (e.g. 16:9 at height 2048
    would end up as 2048x2048).

    Args:
        aspect_ratio: Key of ``ASPECT_RATIOS`` (the UI label).
        current_width: Current width value.
        current_height: Current height value.

    Returns:
        ``(width, height)``. For "Custom" or an unknown label the inputs are
        returned unchanged; otherwise both are ints within [256, 2048].
    """
    min_dim, max_dim = 256, 2048

    # "Custom" maps to None, and unknown labels are absent: both mean "leave as is".
    ratio = ASPECT_RATIOS.get(aspect_ratio)
    if ratio is None:
        return current_width, current_height

    w_ratio, h_ratio = ratio

    if w_ratio == h_ratio:
        new_width = new_height = min(current_width, current_height)
    elif w_ratio < h_ratio:
        # Portrait: width drives, height is the longer side.
        new_width = current_width
        new_height = round_to_step(current_width * h_ratio / w_ratio)
        if new_height > max_dim:
            new_height = max_dim
            new_width = round_to_step(max_dim * w_ratio / h_ratio)
    else:
        # Landscape: height drives, width is the longer side.
        new_height = current_height
        new_width = round_to_step(current_height * w_ratio / h_ratio)
        if new_width > max_dim:
            new_width = max_dim
            new_height = round_to_step(max_dim * h_ratio / w_ratio)

    # Final safety net for inputs that were already outside the allowed range.
    new_width = clamp(new_width, min_dim, max_dim)
    new_height = clamp(new_height, min_dim, max_dim)

    return int(new_width), int(new_height)
|
|
|
|
|
|
def update_height_from_width(width, aspect_ratio):
    """Derive the height that matches ``width`` under the selected preset.

    Returns a no-op ``gr.update()`` when the preset is "Custom" or unknown;
    otherwise the height snapped with ``round_to_step`` and clamped to
    [256, 2048], as an int.
    """
    # "Custom" is stored as None and unknown labels are missing: both yield None.
    ratio = ASPECT_RATIOS.get(aspect_ratio)
    if ratio is None:
        return gr.update()

    w_ratio, h_ratio = ratio
    snapped_height = round_to_step(width * h_ratio / w_ratio)
    return int(clamp(snapped_height, 256, 2048))
|
|
|
|
|
|
def update_width_from_height(height, aspect_ratio):
    """Derive the width that matches ``height`` under the selected preset.

    Returns a no-op ``gr.update()`` when the preset is "Custom" or unknown;
    otherwise the width snapped with ``round_to_step`` and clamped to
    [256, 2048], as an int.
    """
    ratio_is_locked = aspect_ratio != "Custom" and aspect_ratio in ASPECT_RATIOS
    if not ratio_is_locked:
        return gr.update()

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    snapped_width = round_to_step(height * w_ratio / h_ratio)
    bounded_width = clamp(snapped_width, 256, 2048)
    return int(bounded_width)
|
|
|
|
|
|
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
    """Swap width and height and pick the preset label that fits the result.

    Presets are tried in ``ASPECT_RATIOS`` order; the first one whose expected
    height (for the swapped width) is within ``STEP`` pixels of the swapped
    height wins. If none fits, the label is "Custom".

    ``aspect_ratio`` (the currently selected label) is accepted but not used.

    Returns:
        ``(new_width, new_height, new_ratio_label)``.
    """
    swapped_width, swapped_height = height, width

    fitting_labels = (
        label
        for label, pair in ASPECT_RATIOS.items()
        if pair is not None
        and abs(swapped_width * pair[1] / pair[0] - swapped_height) < STEP
    )

    return swapped_width, swapped_height, next(fitting_labels, "Custom")
|
|
|
|
|
|
def create_ui():
    """Build the "Image AI" tab and register its components in ``shared.gradio``.

    The tab has three sub-tabs:
      * Generate - prompts, dimension/aspect-ratio controls, sampling config
        and the output gallery.
      * Gallery  - a refresh button and a gallery of previously saved images.
      * Model    - model selection, load/unload buttons, loader settings and
        a download box.

    Initial values come from ``shared.settings``. Every component is stored
    under an ``image_*`` key in ``shared.gradio`` so that event handlers can
    look it up later.

    The aspect-ratio choices and the slider step are taken from
    ``ASPECT_RATIOS`` and ``STEP`` rather than repeated as literals, so the UI
    cannot drift from the values the dimension helpers work with.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a copy of this file whose indentation had been lost - confirm
    # against the upstream source.

    # Remember the configured model name even though nothing is loaded yet.
    if shared.settings['image_model_menu'] != 'None':
        shared.image_model_name = shared.settings['image_model_menu']

    with gr.Tab("Image AI", elem_id="image-ai-tab"):
        with gr.Tabs():
            # TAB 1: GENERATE
            with gr.TabItem("Generate"):
                with gr.Row():
                    # Left column: inputs and settings.
                    with gr.Column(scale=4, min_width=350):
                        shared.gradio['image_prompt'] = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe your imagination...",
                            lines=3,
                            autofocus=True,
                            value=shared.settings['image_prompt']
                        )
                        shared.gradio['image_neg_prompt'] = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="Low quality...",
                            lines=3,
                            value=shared.settings['image_neg_prompt']
                        )

                        shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
                        gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")

                        gr.Markdown("### 📐 Dimensions")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")
                            with gr.Column():
                                shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")
                            shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")

                        with gr.Row():
                            shared.gradio['image_aspect_ratio'] = gr.Radio(
                                # Same labels, same order as the ASPECT_RATIOS keys.
                                choices=list(ASPECT_RATIOS),
                                value=shared.settings['image_aspect_ratio'],
                                label="Aspect Ratio",
                                interactive=True
                            )

                        gr.Markdown("### ⚙️ Config")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
                                shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
                            with gr.Column():
                                shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                                shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")

                    # Right column: generated images and the seed/status line.
                    with gr.Column(scale=6, min_width=500):
                        with gr.Column(elem_classes=["viewport-container"]):
                            shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True)
                            with gr.Row():
                                shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)

            # TAB 2: GALLERY
            with gr.TabItem("Gallery"):
                with gr.Row():
                    shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
                shared.gradio['image_history_gallery'] = gr.Gallery(label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True)

            # TAB 3: MODEL
            with gr.TabItem("Model"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            shared.gradio['image_model_menu'] = gr.Dropdown(
                                choices=utils.get_available_image_models(),
                                value=shared.settings['image_model_menu'],
                                label='Model',
                                elem_classes='slim-dropdown'
                            )
                            shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
                            shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
                            shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')

                        gr.Markdown("## Settings")
                        with gr.Row():
                            with gr.Column():
                                shared.gradio['image_dtype'] = gr.Dropdown(
                                    choices=['bfloat16', 'float16'],
                                    value=shared.settings['image_dtype'],
                                    label='Data Type',
                                    info='bfloat16 recommended for modern GPUs'
                                )
                                shared.gradio['image_attn_backend'] = gr.Dropdown(
                                    choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
                                    value=shared.settings['image_attn_backend'],
                                    label='Attention Backend',
                                    info='SDPA is default. Flash Attention requires compatible GPU.'
                                )
                            with gr.Column():
                                shared.gradio['image_compile'] = gr.Checkbox(
                                    value=shared.settings['image_compile'],
                                    label='Compile Model',
                                    info='Faster inference after first run. First run will be slow.'
                                )
                                shared.gradio['image_cpu_offload'] = gr.Checkbox(
                                    value=shared.settings['image_cpu_offload'],
                                    label='CPU Offload',
                                    info='Enable for low VRAM GPUs. Slower but uses less memory.'
                                )

                    with gr.Column():
                        shared.gradio['image_download_path'] = gr.Textbox(
                            label="Download model",
                            placeholder="Tongyi-MAI/Z-Image-Turbo",
                            info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
                        )
                        shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
                        shared.gradio['image_model_status'] = gr.Markdown(
                            value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
                        )
|
|
|
|
|
|
def create_event_handlers():
    """Wire the image-tab components in ``shared.gradio`` to their callbacks.

    Covers the dimension controls (aspect ratio, width, height, swap), the
    three generation triggers (button click and Enter in either prompt box),
    model management (refresh, load, unload, download) and the history
    gallery refresh.
    """
    widgets = shared.gradio

    # Dimension controls
    widgets['image_aspect_ratio'].change(
        fn=apply_aspect_ratio,
        inputs=gradio('image_aspect_ratio', 'image_width', 'image_height'),
        outputs=gradio('image_width', 'image_height'),
        show_progress=False,
    )

    widgets['image_width'].release(
        fn=update_height_from_width,
        inputs=gradio('image_width', 'image_aspect_ratio'),
        outputs=gradio('image_height'),
        show_progress=False,
    )

    widgets['image_height'].release(
        fn=update_width_from_height,
        inputs=gradio('image_height', 'image_aspect_ratio'),
        outputs=gradio('image_width'),
        show_progress=False,
    )

    widgets['image_swap_btn'].click(
        fn=swap_dimensions_and_update_ratio,
        inputs=gradio('image_width', 'image_height', 'image_aspect_ratio'),
        outputs=gradio('image_width', 'image_height', 'image_aspect_ratio'),
        show_progress=False,
    )

    # Generation: every trigger first snapshots the interface values into
    # 'interface_state', then runs generate() on that snapshot.
    generation_triggers = (
        widgets['image_generate_btn'].click,
        widgets['image_prompt'].submit,
        widgets['image_neg_prompt'].submit,
    )
    for register in generation_triggers:
        register(
            ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')
        ).then(
            generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed')
        )

    # Model management
    widgets['image_refresh_models'].click(
        fn=lambda: gr.update(choices=utils.get_available_image_models()),
        inputs=None,
        outputs=gradio('image_model_menu'),
        show_progress=False,
    )

    widgets['image_load_model'].click(
        fn=load_image_model_wrapper,
        inputs=gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
        outputs=gradio('image_model_status'),
        show_progress=True,
    )

    widgets['image_unload_model'].click(
        fn=unload_image_model_wrapper,
        inputs=None,
        outputs=gradio('image_model_status'),
        show_progress=False,
    )

    widgets['image_download_btn'].click(
        fn=download_image_model_wrapper,
        inputs=gradio('image_download_path'),
        outputs=gradio('image_model_status', 'image_model_menu'),
        show_progress=True,
    )

    # History
    widgets['image_refresh_history'].click(
        fn=get_history_images,
        inputs=None,
        outputs=gradio('image_history_gallery'),
        show_progress=False,
    )
|
|
|
|
|
|
def generate(state):
    """Generate images from the interface state and save them to disk.

    Loads the selected model on demand if none is loaded, resolves the seed,
    then calls the image model ``image_batch_count`` times, seeding run ``i``
    with ``seed + i``. All resulting images are passed to
    ``save_generated_images`` and returned.

    Args:
        state: Mapping of interface values. Keys read here:
            ``image_model_menu``, ``image_dtype``, ``image_attn_backend``,
            ``image_cpu_offload``, ``image_compile``, ``image_seed``,
            ``image_batch_count``, ``image_batch_size``, ``image_prompt``,
            ``image_neg_prompt``, ``image_width``, ``image_height``,
            ``image_steps``.

    Returns:
        ``(images, status)``: the list of generated images and a status
        string (the base seed on success, an error message otherwise).
    """
    model_name = state['image_model_menu']

    if not model_name or model_name == 'None':
        return [], "No image model selected. Go to the Model tab and select a model."

    # Lazy-load the model on first use.
    if shared.image_model is None:
        result = load_image_model(
            model_name,
            dtype=state['image_dtype'],
            attn_backend=state['image_attn_backend'],
            cpu_offload=state['image_cpu_offload'],
            compile_model=state['image_compile']
        )
        if result is None:
            return [], f"Failed to load model `{model_name}`."

        shared.image_model_name = model_name

    seed = state['image_seed']
    # -1 means "random"; an emptied seed field (None) is treated the same way
    # instead of crashing on int(None).
    if seed is None or seed == -1:
        # Request a 64-bit dtype explicitly: randint's default integer type is
        # platform dependent and can be 32-bit (e.g. Windows with NumPy 1.x),
        # where an upper bound of 2**32 - 1 raises ValueError.
        seed = np.random.randint(0, 2**32 - 1, dtype=np.int64)

    # Normalise to a plain int so the status text and file names are uniform.
    seed = int(seed)

    # Keep the CUDA generator when available; fall back to CPU so machines
    # without CUDA do not fail at torch.Generator("cuda").
    # NOTE(review): assumes the loaded pipeline accepts a CPU generator - confirm.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device)
    all_images = []

    for i in range(int(state['image_batch_count'])):
        # Each sequential run gets its own deterministic seed.
        generator.manual_seed(seed + i)
        batch_results = shared.image_model(
            prompt=state['image_prompt'],
            negative_prompt=state['image_neg_prompt'],
            height=int(state['image_height']),
            width=int(state['image_width']),
            num_inference_steps=int(state['image_steps']),
            guidance_scale=0.0,
            num_images_per_prompt=int(state['image_batch_size']),
            generator=generator,
        ).images
        all_images.extend(batch_results)

    save_generated_images(all_images, state['image_prompt'], seed)
    return all_images, f"Seed: {seed}"
|
|
|
|
|
|
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
    """Load an image model, yielding status strings as the work progresses.

    Yields "No model selected" and stops when ``model_name`` is empty or the
    string 'None'. Otherwise yields a "Loading" message, unloads the current
    model, loads the requested one and yields a success or failure message.
    On success ``shared.image_model_name`` is updated. Any exception raised
    along the way is yielded as a formatted traceback instead of propagating.
    """
    if not model_name or model_name == 'None':
        yield "No model selected"
        return

    try:
        yield f"Loading `{model_name}`..."
        unload_image_model()

        load_options = {
            'dtype': dtype,
            'attn_backend': attn_backend,
            'cpu_offload': cpu_offload,
            'compile_model': compile_model,
        }
        loaded = load_image_model(model_name, **load_options)

        if loaded is None:
            yield f"✗ Failed to load `{model_name}`"
        else:
            shared.image_model_name = model_name
            yield f"✓ Loaded **{model_name}**"
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```"
|
|
|
|
|
|
def unload_image_model_wrapper():
    """Unload the image model and return the status text to display.

    The text names the remembered model as "(not loaded)" unless
    ``shared.image_model_name`` is the string 'None'.
    """
    unload_image_model()

    remembered_name = shared.image_model_name
    if remembered_name == 'None':
        return "No model loaded"

    return f"Model: **{remembered_name}** (not loaded)"
|
|
|
|
|
|
def download_image_model_wrapper(model_path):
    """Download a model snapshot from the Hugging Face Hub, yielding progress.

    ``model_path`` has the form ``user/model`` or ``user/model:branch``; the
    branch defaults to ``main``. Surrounding whitespace (easy to pick up when
    pasting) is stripped from the whole input and from both parts, and an
    empty branch after a trailing ``:`` falls back to ``main``.

    The snapshot is stored in ``shared.args.image_model_dir`` under the model
    id with ``/`` replaced by ``_``.

    Yields:
        ``(status, dropdown_update)`` pairs: a status string and an update
        for the model dropdown. The final pair on success refreshes the
        dropdown choices and selects the downloaded folder. Exceptions raised
        during the download are yielded as a formatted traceback.
    """
    from huggingface_hub import snapshot_download

    # Treat None and whitespace-only input the same as an empty box.
    model_path = (model_path or "").strip()
    if not model_path:
        yield "No model specified", gr.update()
        return

    try:
        if ':' in model_path:
            model_id, branch = model_path.rsplit(':', 1)
        else:
            model_id, branch = model_path, 'main'

        model_id = model_id.strip()
        # "user/model:" would otherwise request an empty revision.
        branch = branch.strip() or 'main'

        if not model_id:
            yield "No model specified", gr.update()
            return

        folder_name = model_id.replace('/', '_')
        output_folder = Path(shared.args.image_model_dir) / folder_name

        yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()

        snapshot_download(
            repo_id=model_id,
            revision=branch,
            local_dir=output_folder,
            local_dir_use_symlinks=False,
        )

        new_choices = utils.get_available_image_models()
        yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
    except Exception:
        yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
|
|
|
|
|
|
def save_generated_images(images, prompt, seed):
    """Save each image as a PNG under ``user_data/image_outputs/<YYYY-MM-DD>/``.

    Files are named ``<HH-MM-SS>_<seed>_<index>.png``, where the time is read
    when each image is written and ``index`` is the image's position in
    ``images``. The folder is created if needed.

    ``prompt`` is accepted but not used.
    """
    today = datetime.now().strftime("%Y-%m-%d")
    target_dir = os.path.join("user_data", "image_outputs", today)
    os.makedirs(target_dir, exist_ok=True)

    for position, image in enumerate(images):
        clock = datetime.now().strftime("%H-%M-%S")
        destination = os.path.join(target_dir, f"{clock}_{seed}_{position}.png")
        image.save(destination)
|
|
|
|
|
|
def get_history_images():
    """Return the paths of saved images, most recently modified first.

    Walks ``user_data/image_outputs`` recursively and collects files whose
    names end in ``.png``, ``.jpg`` or ``.jpeg``. Returns an empty list when
    the directory does not exist.
    """
    history_root = os.path.join("user_data", "image_outputs")
    if not os.path.exists(history_root):
        return []

    image_suffixes = (".png", ".jpg", ".jpeg")
    stamped_paths = []
    for folder, _subdirs, names in os.walk(history_root):
        for name in names:
            if not name.endswith(image_suffixes):
                continue
            path = os.path.join(folder, name)
            stamped_paths.append((path, os.path.getmtime(path)))

    newest_first = sorted(stamped_paths, key=lambda entry: entry[1], reverse=True)
    return [path for path, _mtime in newest_first]
|