mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Add PNG metadata, add pagination to Gallery tab
This commit is contained in:
parent
c8e9d7fc37
commit
6a7209a842
50
css/main.css
50
css/main.css
|
|
@ -1688,6 +1688,52 @@ button#swap-height-width {
|
|||
}
|
||||
|
||||
#image-history-gallery, #image-history-gallery > :nth-child(2) {
|
||||
height: calc(100vh - 139px);
|
||||
max-height: calc(100vh - 139px);
|
||||
height: calc(100vh - 174px);
|
||||
max-height: calc(100vh - 174px);
|
||||
}
|
||||
|
||||
/* Additional CSS for the paginated image gallery */
|
||||
|
||||
/* Page info styling */
|
||||
#image-page-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
min-width: 200px;
|
||||
font-size: 0.9em;
|
||||
color: var(--body-text-color-subdued);
|
||||
}
|
||||
|
||||
/* Settings display panel */
|
||||
#image-ai-tab .settings-display-panel {
|
||||
background: var(--background-fill-secondary);
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
font-size: 0.9em;
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
/* Gallery status message */
|
||||
#image-ai-tab .gallery-status {
|
||||
color: var(--color-accent);
|
||||
font-size: 0.85em;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
/* Pagination button row alignment */
|
||||
#image-ai-tab .pagination-controls {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
/* Selected image preview container */
|
||||
#image-ai-tab .selected-preview-container {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
border-radius: 8px;
|
||||
padding: 8px;
|
||||
background: var(--background-fill-secondary);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
|
|
@ -7,6 +8,8 @@ from pathlib import Path
|
|||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
from modules import shared, ui, utils
|
||||
from modules.image_models import load_image_model, unload_image_model
|
||||
|
|
@ -22,6 +25,24 @@ ASPECT_RATIOS = {
|
|||
}
|
||||
|
||||
STEP = 32
|
||||
IMAGES_PER_PAGE = 64
|
||||
|
||||
# Settings keys to save in PNG metadata (Generate tab only)
|
||||
METADATA_SETTINGS_KEYS = [
|
||||
'image_prompt',
|
||||
'image_neg_prompt',
|
||||
'image_width',
|
||||
'image_height',
|
||||
'image_aspect_ratio',
|
||||
'image_steps',
|
||||
'image_seed',
|
||||
'image_batch_size',
|
||||
'image_batch_count',
|
||||
]
|
||||
|
||||
# Cache for all image paths
|
||||
_image_cache = []
|
||||
_cache_timestamp = 0
|
||||
|
||||
|
||||
def round_to_step(value, step=STEP):
|
||||
|
|
@ -93,6 +114,215 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
|
|||
return new_width, new_height, new_ratio
|
||||
|
||||
|
||||
def build_generation_metadata(state, actual_seed):
|
||||
"""Build metadata dict from generation settings."""
|
||||
metadata = {}
|
||||
for key in METADATA_SETTINGS_KEYS:
|
||||
if key in state:
|
||||
metadata[key] = state[key]
|
||||
|
||||
# Store the actual seed used (not -1)
|
||||
metadata['image_seed'] = actual_seed
|
||||
metadata['generated_at'] = datetime.now().isoformat()
|
||||
metadata['model'] = shared.image_model_name
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def save_generated_images(images, state, actual_seed):
|
||||
"""Save images with generation metadata embedded in PNG."""
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
folder_path = os.path.join("user_data", "image_outputs", date_str)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
metadata = build_generation_metadata(state, actual_seed)
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
|
||||
for idx, img in enumerate(images):
|
||||
timestamp = datetime.now().strftime("%H-%M-%S")
|
||||
filename = f"{timestamp}_{actual_seed}_{idx}.png"
|
||||
filepath = os.path.join(folder_path, filename)
|
||||
|
||||
# Create PNG metadata
|
||||
png_info = PngInfo()
|
||||
png_info.add_text("image_gen_settings", metadata_json)
|
||||
|
||||
# Save with metadata
|
||||
img.save(filepath, pnginfo=png_info)
|
||||
|
||||
|
||||
def read_image_metadata(image_path):
|
||||
"""Read generation metadata from PNG file."""
|
||||
try:
|
||||
with Image.open(image_path) as img:
|
||||
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
|
||||
return json.loads(img.text['image_gen_settings'])
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not read metadata from {image_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def format_metadata_for_display(metadata):
|
||||
"""Format metadata as readable text."""
|
||||
if not metadata:
|
||||
return "No generation settings found in this image."
|
||||
|
||||
lines = ["**Generation Settings**", ""]
|
||||
|
||||
# Display in a nice order
|
||||
display_order = [
|
||||
('image_prompt', 'Prompt'),
|
||||
('image_neg_prompt', 'Negative Prompt'),
|
||||
('image_width', 'Width'),
|
||||
('image_height', 'Height'),
|
||||
('image_aspect_ratio', 'Aspect Ratio'),
|
||||
('image_steps', 'Steps'),
|
||||
('image_seed', 'Seed'),
|
||||
('image_batch_size', 'Batch Size'),
|
||||
('image_batch_count', 'Batch Count'),
|
||||
('model', 'Model'),
|
||||
('generated_at', 'Generated At'),
|
||||
]
|
||||
|
||||
for key, label in display_order:
|
||||
if key in metadata:
|
||||
value = metadata[key]
|
||||
if key in ['image_prompt', 'image_neg_prompt'] and value:
|
||||
# Truncate long prompts for display
|
||||
if len(str(value)) > 200:
|
||||
value = str(value)[:200] + "..."
|
||||
lines.append(f"**{label}:** {value}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
|
||||
def get_all_history_images(force_refresh=False):
|
||||
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
||||
global _image_cache, _cache_timestamp
|
||||
|
||||
output_dir = os.path.join("user_data", "image_outputs")
|
||||
if not os.path.exists(output_dir):
|
||||
return []
|
||||
|
||||
# Check if we need to refresh cache
|
||||
current_time = time.time()
|
||||
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
|
||||
return _image_cache
|
||||
|
||||
image_files = []
|
||||
for root, _, files in os.walk(output_dir):
|
||||
for file in files:
|
||||
if file.endswith((".png", ".jpg", ".jpeg")):
|
||||
full_path = os.path.join(root, file)
|
||||
image_files.append((full_path, os.path.getmtime(full_path)))
|
||||
|
||||
image_files.sort(key=lambda x: x[1], reverse=True)
|
||||
_image_cache = [x[0] for x in image_files]
|
||||
_cache_timestamp = current_time
|
||||
|
||||
return _image_cache
|
||||
|
||||
|
||||
def get_paginated_images(page=0, force_refresh=False):
|
||||
"""Get images for a specific page."""
|
||||
all_images = get_all_history_images(force_refresh)
|
||||
total_images = len(all_images)
|
||||
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
|
||||
|
||||
# Clamp page to valid range
|
||||
page = max(0, min(page, total_pages - 1))
|
||||
|
||||
start_idx = page * IMAGES_PER_PAGE
|
||||
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
|
||||
|
||||
page_images = all_images[start_idx:end_idx]
|
||||
|
||||
return page_images, page, total_pages, total_images
|
||||
|
||||
|
||||
def refresh_gallery(current_page=0):
|
||||
"""Refresh gallery with current page."""
|
||||
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
|
||||
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||
return images, page, page_info
|
||||
|
||||
|
||||
def go_to_page(page_num, current_page):
|
||||
"""Go to a specific page (1-indexed input)."""
|
||||
try:
|
||||
page = int(page_num) - 1 # Convert to 0-indexed
|
||||
except (ValueError, TypeError):
|
||||
page = current_page
|
||||
|
||||
images, page, total_pages, total_images = get_paginated_images(page)
|
||||
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||
return images, page, page_info
|
||||
|
||||
|
||||
def next_page(current_page):
|
||||
"""Go to next page."""
|
||||
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
|
||||
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||
return images, page, page_info
|
||||
|
||||
|
||||
def prev_page(current_page):
|
||||
"""Go to previous page."""
|
||||
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
|
||||
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||
return images, page, page_info
|
||||
|
||||
|
||||
def on_gallery_select(evt: gr.SelectData, current_page):
|
||||
"""Handle image selection from gallery."""
|
||||
if evt.index is None:
|
||||
return "", "Select an image to view its settings"
|
||||
|
||||
# Get the current page's images to find the actual file path
|
||||
all_images = get_all_history_images()
|
||||
total_images = len(all_images)
|
||||
|
||||
# Calculate the actual index in the full list
|
||||
start_idx = current_page * IMAGES_PER_PAGE
|
||||
actual_idx = start_idx + evt.index
|
||||
|
||||
if actual_idx >= total_images:
|
||||
return "", "Image not found"
|
||||
|
||||
image_path = all_images[actual_idx]
|
||||
metadata = read_image_metadata(image_path)
|
||||
metadata_display = format_metadata_for_display(metadata)
|
||||
|
||||
return image_path, metadata_display
|
||||
|
||||
|
||||
def send_to_generate(selected_image_path):
|
||||
"""Load settings from selected image and return updates for all Generate tab inputs."""
|
||||
if not selected_image_path or not os.path.exists(selected_image_path):
|
||||
return [gr.update()] * 9 + ["No image selected"]
|
||||
|
||||
metadata = read_image_metadata(selected_image_path)
|
||||
if not metadata:
|
||||
return [gr.update()] * 9 + ["No settings found in this image"]
|
||||
|
||||
# Return updates for each input element in order
|
||||
updates = [
|
||||
gr.update(value=metadata.get('image_prompt', '')),
|
||||
gr.update(value=metadata.get('image_neg_prompt', '')),
|
||||
gr.update(value=metadata.get('image_width', 1024)),
|
||||
gr.update(value=metadata.get('image_height', 1024)),
|
||||
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
|
||||
gr.update(value=metadata.get('image_steps', 9)),
|
||||
gr.update(value=metadata.get('image_seed', -1)),
|
||||
gr.update(value=metadata.get('image_batch_size', 1)),
|
||||
gr.update(value=metadata.get('image_batch_count', 1)),
|
||||
]
|
||||
|
||||
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
|
||||
return updates + [status]
|
||||
|
||||
|
||||
|
||||
def create_ui():
|
||||
if shared.settings['image_model_menu'] != 'None':
|
||||
shared.image_model_name = shared.settings['image_model_menu']
|
||||
|
|
@ -149,11 +379,40 @@ def create_ui():
|
|||
with gr.Column(elem_classes=["viewport-container"]):
|
||||
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
|
||||
|
||||
# TAB 2: GALLERY
|
||||
# TAB 2: GALLERY (with pagination)
|
||||
with gr.TabItem("Gallery"):
|
||||
with gr.Row():
|
||||
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
|
||||
shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
|
||||
with gr.Column(scale=3):
|
||||
# Pagination controls
|
||||
with gr.Row():
|
||||
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
|
||||
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
|
||||
shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
|
||||
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
|
||||
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
|
||||
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
|
||||
|
||||
# State for current page and selected image path
|
||||
shared.gradio['image_current_page'] = gr.State(value=0)
|
||||
shared.gradio['image_selected_path'] = gr.State(value="")
|
||||
|
||||
# Paginated gallery using gr.Gallery
|
||||
shared.gradio['image_history_gallery'] = gr.Gallery(
|
||||
value=lambda: get_paginated_images(0)[0],
|
||||
label="Image History",
|
||||
show_label=False,
|
||||
columns=6,
|
||||
object_fit="cover",
|
||||
height="auto",
|
||||
allow_preview=True,
|
||||
elem_id="image-history-gallery"
|
||||
)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown("### Selected Image")
|
||||
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
|
||||
shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
|
||||
shared.gradio['image_gallery_status'] = gr.Markdown("")
|
||||
|
||||
# TAB 3: MODEL
|
||||
with gr.TabItem("Model"):
|
||||
|
|
@ -281,11 +540,59 @@ def create_event_handlers():
|
|||
show_progress=True
|
||||
)
|
||||
|
||||
# History
|
||||
# Gallery pagination handlers
|
||||
shared.gradio['image_refresh_history'].click(
|
||||
get_history_images,
|
||||
None,
|
||||
gradio('image_history_gallery'),
|
||||
refresh_gallery,
|
||||
gradio('image_current_page'),
|
||||
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['image_next_page'].click(
|
||||
next_page,
|
||||
gradio('image_current_page'),
|
||||
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['image_prev_page'].click(
|
||||
prev_page,
|
||||
gradio('image_current_page'),
|
||||
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['image_go_to_page'].click(
|
||||
go_to_page,
|
||||
gradio('image_page_input', 'image_current_page'),
|
||||
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
# Image selection from gallery
|
||||
shared.gradio['image_history_gallery'].select(
|
||||
on_gallery_select,
|
||||
gradio('image_current_page'),
|
||||
gradio('image_selected_path', 'image_settings_display'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
# Send to Generate
|
||||
shared.gradio['image_send_to_generate'].click(
|
||||
send_to_generate,
|
||||
gradio('image_selected_path'),
|
||||
gradio(
|
||||
'image_prompt',
|
||||
'image_neg_prompt',
|
||||
'image_width',
|
||||
'image_height',
|
||||
'image_aspect_ratio',
|
||||
'image_steps',
|
||||
'image_seed',
|
||||
'image_batch_size',
|
||||
'image_batch_count',
|
||||
'image_gallery_status'
|
||||
),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
|
|
@ -334,7 +641,7 @@ def generate(state):
|
|||
all_images.extend(batch_results)
|
||||
|
||||
t1 = time.time()
|
||||
save_generated_images(all_images, state['image_prompt'], seed)
|
||||
save_generated_images(all_images, state, seed)
|
||||
|
||||
logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
|
||||
return all_images
|
||||
|
|
@ -402,30 +709,3 @@ def download_image_model_wrapper(model_path):
|
|||
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
|
||||
except Exception:
|
||||
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
|
||||
|
||||
|
||||
def save_generated_images(images, prompt, seed):
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
folder_path = os.path.join("user_data", "image_outputs", date_str)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
for idx, img in enumerate(images):
|
||||
timestamp = datetime.now().strftime("%H-%M-%S")
|
||||
filename = f"{timestamp}_{seed}_{idx}.png"
|
||||
img.save(os.path.join(folder_path, filename))
|
||||
|
||||
|
||||
def get_history_images():
|
||||
output_dir = os.path.join("user_data", "image_outputs")
|
||||
if not os.path.exists(output_dir):
|
||||
return []
|
||||
|
||||
image_files = []
|
||||
for root, _, files in os.walk(output_dir):
|
||||
for file in files:
|
||||
if file.endswith((".png", ".jpg", ".jpeg")):
|
||||
full_path = os.path.join(root, file)
|
||||
image_files.append((full_path, os.path.getmtime(full_path)))
|
||||
|
||||
image_files.sort(key=lambda x: x[1], reverse=True)
|
||||
return [x[0] for x in image_files]
|
||||
|
|
|
|||
Loading…
Reference in a new issue