Add PNG metadata, add pagination to Gallery tab

This commit is contained in:
oobabooga 2025-12-01 15:41:58 -08:00
parent c8e9d7fc37
commit 6a7209a842
2 changed files with 363 additions and 37 deletions

View file

@ -1688,6 +1688,52 @@ button#swap-height-width {
} }
#image-history-gallery, #image-history-gallery > :nth-child(2) { #image-history-gallery, #image-history-gallery > :nth-child(2) {
height: calc(100vh - 139px); height: calc(100vh - 174px);
max-height: calc(100vh - 139px); max-height: calc(100vh - 174px);
}
/* Additional CSS for the paginated image gallery */
/* Page info styling */
#image-page-info {
display: flex;
align-items: center;
justify-content: center;
min-width: 200px;
font-size: 0.9em;
color: var(--body-text-color-subdued);
}
/* Settings display panel */
#image-ai-tab .settings-display-panel {
background: var(--background-fill-secondary);
padding: 12px;
border-radius: 8px;
font-size: 0.9em;
max-height: 300px;
overflow-y: auto;
margin-top: 8px;
}
/* Gallery status message */
#image-ai-tab .gallery-status {
color: var(--color-accent);
font-size: 0.85em;
margin-top: 4px;
}
/* Pagination button row alignment */
#image-ai-tab .pagination-controls {
display: flex;
align-items: center;
gap: 8px;
flex-wrap: wrap;
}
/* Selected image preview container */
#image-ai-tab .selected-preview-container {
border: 1px solid var(--border-color-primary);
border-radius: 8px;
padding: 8px;
background: var(--background-fill-secondary);
} }

View file

@ -1,3 +1,4 @@
import json
import os import os
import time import time
import traceback import traceback
@ -7,6 +8,8 @@ from pathlib import Path
import gradio as gr import gradio as gr
import numpy as np import numpy as np
import torch import torch
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils from modules import shared, ui, utils
from modules.image_models import load_image_model, unload_image_model from modules.image_models import load_image_model, unload_image_model
@ -22,6 +25,24 @@ ASPECT_RATIOS = {
} }
STEP = 32 STEP = 32
IMAGES_PER_PAGE = 64
# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
]
# Cache for all image paths
_image_cache = []
_cache_timestamp = 0
def round_to_step(value, step=STEP): def round_to_step(value, step=STEP):
@ -93,6 +114,215 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
return new_width, new_height, new_ratio return new_width, new_height, new_ratio
def build_generation_metadata(state, actual_seed):
"""Build metadata dict from generation settings."""
metadata = {}
for key in METADATA_SETTINGS_KEYS:
if key in state:
metadata[key] = state[key]
# Store the actual seed used (not -1)
metadata['image_seed'] = actual_seed
metadata['generated_at'] = datetime.now().isoformat()
metadata['model'] = shared.image_model_name
return metadata
def save_generated_images(images, state, actual_seed):
"""Save images with generation metadata embedded in PNG."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
metadata = build_generation_metadata(state, actual_seed)
metadata_json = json.dumps(metadata, ensure_ascii=False)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{actual_seed}_{idx}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
png_info = PngInfo()
png_info.add_text("image_gen_settings", metadata_json)
# Save with metadata
img.save(filepath, pnginfo=png_info)
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with Image.open(image_path) as img:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
except Exception as e:
logger.debug(f"Could not read metadata from {image_path}: {e}")
return None
def format_metadata_for_display(metadata):
"""Format metadata as readable text."""
if not metadata:
return "No generation settings found in this image."
lines = ["**Generation Settings**", ""]
# Display in a nice order
display_order = [
('image_prompt', 'Prompt'),
('image_neg_prompt', 'Negative Prompt'),
('image_width', 'Width'),
('image_height', 'Height'),
('image_aspect_ratio', 'Aspect Ratio'),
('image_steps', 'Steps'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
('model', 'Model'),
('generated_at', 'Generated At'),
]
for key, label in display_order:
if key in metadata:
value = metadata[key]
if key in ['image_prompt', 'image_neg_prompt'] and value:
# Truncate long prompts for display
if len(str(value)) > 200:
value = str(value)[:200] + "..."
lines.append(f"**{label}:** {value}")
return "\n\n".join(lines)
def get_all_history_images(force_refresh=False):
"""Get all history images sorted by modification time (newest first). Uses caching."""
global _image_cache, _cache_timestamp
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
# Check if we need to refresh cache
current_time = time.time()
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
return _image_cache
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
_image_cache = [x[0] for x in image_files]
_cache_timestamp = current_time
return _image_cache
def get_paginated_images(page=0, force_refresh=False):
"""Get images for a specific page."""
all_images = get_all_history_images(force_refresh)
total_images = len(all_images)
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
# Clamp page to valid range
page = max(0, min(page, total_pages - 1))
start_idx = page * IMAGES_PER_PAGE
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
page_images = all_images[start_idx:end_idx]
return page_images, page, total_pages, total_images
def refresh_gallery(current_page=0):
"""Refresh gallery with current page."""
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def go_to_page(page_num, current_page):
"""Go to a specific page (1-indexed input)."""
try:
page = int(page_num) - 1 # Convert to 0-indexed
except (ValueError, TypeError):
page = current_page
images, page, total_pages, total_images = get_paginated_images(page)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def next_page(current_page):
"""Go to next page."""
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def prev_page(current_page):
"""Go to previous page."""
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
return images, page, page_info
def on_gallery_select(evt: gr.SelectData, current_page):
"""Handle image selection from gallery."""
if evt.index is None:
return "", "Select an image to view its settings"
# Get the current page's images to find the actual file path
all_images = get_all_history_images()
total_images = len(all_images)
# Calculate the actual index in the full list
start_idx = current_page * IMAGES_PER_PAGE
actual_idx = start_idx + evt.index
if actual_idx >= total_images:
return "", "Image not found"
image_path = all_images[actual_idx]
metadata = read_image_metadata(image_path)
metadata_display = format_metadata_for_display(metadata)
return image_path, metadata_display
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
return [gr.update()] * 9 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
return [gr.update()] * 9 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
gr.update(value=metadata.get('image_prompt', '')),
gr.update(value=metadata.get('image_neg_prompt', '')),
gr.update(value=metadata.get('image_width', 1024)),
gr.update(value=metadata.get('image_height', 1024)),
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
gr.update(value=metadata.get('image_steps', 9)),
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
]
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
return updates + [status]
def create_ui(): def create_ui():
if shared.settings['image_model_menu'] != 'None': if shared.settings['image_model_menu'] != 'None':
shared.image_model_name = shared.settings['image_model_menu'] shared.image_model_name = shared.settings['image_model_menu']
@ -149,11 +379,40 @@ def create_ui():
with gr.Column(elem_classes=["viewport-container"]): with gr.Column(elem_classes=["viewport-container"]):
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery") shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
# TAB 2: GALLERY # TAB 2: GALLERY (with pagination)
with gr.TabItem("Gallery"): with gr.TabItem("Gallery"):
with gr.Row(): with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button") with gr.Column(scale=3):
shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery") # Pagination controls
with gr.Row():
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
# State for current page and selected image path
shared.gradio['image_current_page'] = gr.State(value=0)
shared.gradio['image_selected_path'] = gr.State(value="")
# Paginated gallery using gr.Gallery
shared.gradio['image_history_gallery'] = gr.Gallery(
value=lambda: get_paginated_images(0)[0],
label="Image History",
show_label=False,
columns=6,
object_fit="cover",
height="auto",
allow_preview=True,
elem_id="image-history-gallery"
)
with gr.Column(scale=1):
gr.Markdown("### Selected Image")
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary")
shared.gradio['image_gallery_status'] = gr.Markdown("")
# TAB 3: MODEL # TAB 3: MODEL
with gr.TabItem("Model"): with gr.TabItem("Model"):
@ -281,11 +540,59 @@ def create_event_handlers():
show_progress=True show_progress=True
) )
# History # Gallery pagination handlers
shared.gradio['image_refresh_history'].click( shared.gradio['image_refresh_history'].click(
get_history_images, refresh_gallery,
None, gradio('image_current_page'),
gradio('image_history_gallery'), gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_next_page'].click(
next_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_prev_page'].click(
prev_page,
gradio('image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
shared.gradio['image_go_to_page'].click(
go_to_page,
gradio('image_page_input', 'image_current_page'),
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
show_progress=False
)
# Image selection from gallery
shared.gradio['image_history_gallery'].select(
on_gallery_select,
gradio('image_current_page'),
gradio('image_selected_path', 'image_settings_display'),
show_progress=False
)
# Send to Generate
shared.gradio['image_send_to_generate'].click(
send_to_generate,
gradio('image_selected_path'),
gradio(
'image_prompt',
'image_neg_prompt',
'image_width',
'image_height',
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_gallery_status'
),
show_progress=False show_progress=False
) )
@ -334,7 +641,7 @@ def generate(state):
all_images.extend(batch_results) all_images.extend(batch_results)
t1 = time.time() t1 = time.time()
save_generated_images(all_images, state['image_prompt'], seed) save_generated_images(all_images, state, seed)
logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})') logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})')
return all_images return all_images
@ -402,30 +709,3 @@ def download_image_model_wrapper(model_path):
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name) yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
except Exception: except Exception:
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update() yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
def save_generated_images(images, prompt, seed):
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{seed}_{idx}.png"
img.save(os.path.join(folder_path, filename))
def get_history_images():
output_dir = os.path.join("user_data", "image_outputs")
if not os.path.exists(output_dir):
return []
image_files = []
for root, _, files in os.walk(output_dir):
for file in files:
if file.endswith((".png", ".jpg", ".jpeg")):
full_path = os.path.join(root, file)
image_files.append((full_path, os.path.getmtime(full_path)))
image_files.sort(key=lambda x: x[1], reverse=True)
return [x[0] for x in image_files]