Add PNG metadata, add pagination to Gallery tab

This commit is contained in:
oobabooga 2025-12-01 15:41:58 -08:00
parent c8e9d7fc37
commit 6a7209a842
2 changed files with 363 additions and 37 deletions

View file

@ -1688,6 +1688,52 @@ button#swap-height-width {
}
#image-history-gallery, #image-history-gallery > :nth-child(2) {
height: calc(100vh - 139px);
max-height: calc(100vh - 139px);
height: calc(100vh - 174px);
max-height: calc(100vh - 174px);
}
/* Additional CSS for the paginated image gallery */
/* Page info styling */
#image-page-info {
display: flex;
align-items: center;
justify-content: center;
min-width: 200px;
font-size: 0.9em;
color: var(--body-text-color-subdued);
}
/* Settings display panel */
#image-ai-tab .settings-display-panel {
background: var(--background-fill-secondary);
padding: 12px;
border-radius: 8px;
font-size: 0.9em;
max-height: 300px;
overflow-y: auto;
margin-top: 8px;
}
/* Gallery status message */
#image-ai-tab .gallery-status {
color: var(--color-accent);
font-size: 0.85em;
margin-top: 4px;
}
/* Pagination button row alignment */
#image-ai-tab .pagination-controls {
display: flex;
align-items: center;
gap: 8px;
flex-wrap: wrap;
}
/* Selected image preview container */
#image-ai-tab .selected-preview-container {
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 8px;
background: var(--background-fill-secondary);
}

View file

@ -1,3 +1,4 @@
import json
import os
import time
import traceback
@ -7,6 +8,8 @@ from pathlib import Path
import gradio as gr
import numpy as np
import torch
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model
@ -22,6 +25,24 @@ ASPECT_RATIOS = {
}
STEP = 32
IMAGES_PER_PAGE = 64
# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
]
# Cache for all image paths
_image_cache = []
_cache_timestamp = 0
def round_to_step(value, step=STEP):
@ -93,6 +114,215 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
return new_width, new_height, new_ratio
def build_generation_metadata(state, actual_seed):
"""Build metadata dict from generation settings."""
metadata = {}
for key in METADATA_SETTINGS_KEYS:
if key in state:
metadata[key] = state[key]
# Store the actual seed used (not -1)
metadata['image_seed'] = actual_seed
metadata['generated_at'] = datetime.now().isoformat()
metadata['model'] = shared.image_model_name
return metadata
def save_generated_images(images, state, actual_seed):
"""Save images with generation metadata embedded in PNG."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
metadata = build_generation_metadata(state, actual_seed)
metadata_json = json.dumps(metadata, ensure_ascii=False)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{actual_seed}_{idx}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
png_info = PngInfo()
png_info.add_text("image_gen_settings", metadata_json)
# Save with metadata
img.save(filepath, pnginfo=png_info)
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with Image.open(image_path) as img:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
except Exception as e:
logger.debug(f"Could not read metadata from {image_path}: {e}")
return None
def format_metadata_for_display(metadata):
"""Format metadata as readable text."""
if not metadata:
return "No generation settings found in this image."
lines = ["**Generation Settings**", ""]
# Display in a nice order
display_order = [
('image_prompt', 'Prompt'),
('image_neg_prompt', 'Negative Prompt'),
('image_width', 'Width'),
('image_height', 'Height'),
('image_aspect_ratio', 'Aspect Ratio'),
('image_steps', 'Steps'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
('model', 'Model'),
('generated_at', 'Generated At'),
]
for key, label in display_order:
if key in metadata:
value = metadata[key]
if key in ['image_prompt', 'image_neg_prompt'] and value:
# Truncate long prompts for display
if len(str(value)) > 200:
value = str(value)[:200] + "..."
lines.append(f"**{label}:** {value}")
return "\n\n".join(lines)
def get_all_history_images(force_refresh=False):
"""Get all history images sorted by modification time (newest first). Uses caching."""
global _image_cache, _cache_timestamp
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
# Check if we need to refresh cache
current_time = time.time()
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
return _image_cache
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
_image_cache = [x[0] for x in image_files]
_cache_timestamp = current_time
return _image_cache
def get_paginated_images(page=0, force_refresh=False):
"""Get images for a specific page."""
all_images = get_all_history_images(force_refresh)
total_images = len(all_images)
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
# Clamp page to valid range
page = max(0, min(page, total_pages - 1))
start_idx = page * IMAGES_PER_PAGE
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
page_images = all_images[start_idx:end_idx]
return page_images, page, total_pages, total_images
def refresh_gallery(current_page=0):
"""Refresh gallery with current page."""
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def go_to_page(page_num, current_page):
"""Go to a specific page (1-indexed input)."""
try:
page = int(page_num) - 1 # Convert to 0-indexed
except (ValueError, TypeError):
page = current_page
images, page, total_pages, total_images = get_paginated_images(page)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def next_page(current_page):
"""Go to next page."""
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def prev_page(current_page):
"""Go to previous page."""
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def on_gallery_select(evt: gr.SelectData, current_page):
"""Handle image selection from gallery."""
if evt.index is None:
return "", "Select an image to view its settings"
# Get the current page's images to find the actual file path
all_images = get_all_history_images()
total_images = len(all_images)
# Calculate the actual index in the full list
start_idx = current_page * IMAGES_PER_PAGE
actual_idx = start_idx + evt.index
if actual_idx >= total_images:
return "", "Image not found"
image_path = all_images[actual_idx]
metadata = read_image_metadata(image_path)
metadata_display = format_metadata_for_display(metadata)
return image_path, metadata_display
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
return [gr.update()] * 9 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
return [gr.update()] * 9 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
gr.update(value=metadata.get('image_prompt', '')),
gr.update(value=metadata.get('image_neg_prompt', '')),
gr.update(value=metadata.get('image_width', 1024)),
gr.update(value=metadata.get('image_height', 1024)),
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
gr.update(value=metadata.get('image_steps', 9)),
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
]
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
return updates + [status]
def create_ui():
if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu']
@ -149,11 +379,40 @@ def create_ui():
with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
# TAB 2: GALLERY
# TAB 2: GALLERY (with pagination)
with gr.TabItem("Gallery"):
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
with gr.Column(scale=3):
# Pagination controls
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
# State for current page and selected image path
shared.gradio['image_current_page'] = gr.State(value=0)
shared.gradio['image_selected_path'] = gr.State(value="")
# Paginated gallery using gr.Gallery
shared.gradio['image_history_gallery'] = gr.Gallery(
value=lambda: get_paginated_images(0)[0],
label="Image History",
show_label=False,
columns=6,
object_fit="cover",
height="auto",
allow_preview=True,
elem_id="image-history-gallery"
)
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL
with gr.TabItem("Model"):
@ -281,11 +540,59 @@ def create_event_handlers():
show_progress=True
)
# History
# Gallery pagination handlers
shared.gradio['image_refresh_history'].click(
get_history_images,
None,
gradio('image_history_gallery'),
refresh_gallery,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_next_page'].click(
next_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_prev_page'].click(
prev_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_go_to_page'].click(
go_to_page,
gradio('image_page_input', 'image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
# Image selection from gallery
shared.gradio['image_history_gallery'].select(
on_gallery_select,
gradio('image_current_page'),
gradio('image_selected_path', 'image_settings_display'),
show_progress=False
)
# Send to Generate
shared.gradio['image_send_to_generate'].click(
send_to_generate,
gradio('image_selected_path'),
gradio(
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_gallery_status'
),
show_progress=False
)
@ -334,7 +641,7 @@ def generate(state):
all_images.extend(batch_results)
t1 = time.time()
save_generated_images(all_images, state['image_prompt'], seed)
save_generated_images(all_images, state, seed)
logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
return all_images
@ -402,30 +709,3 @@ def download_image_model_wrapper(model_path):
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{seed}_{idx}.png"
img.save(os.path.join(folder_path, filename))
def get_history_images():
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]