From 6a7209a8422c94dde56f4638c233532c2e7ce002 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:41:58 -0800 Subject: [PATCH] Add PNG metadata, add pagination to Gallery tab --- css/main.css | 50 ++++- modules/ui_image_generation.py | 350 +++++++++++++++++++++++++++++---- 2 files changed, 363 insertions(+), 37 deletions(-) diff --git a/css/main.css b/css/main.css index 26687eb4..0bfdca0a 100644 --- a/css/main.css +++ b/css/main.css @@ -1688,6 +1688,52 @@ button#swap-height-width { } #image-history-gallery, #image-history-gallery > :nth-child(2) { - height: calc(100vh - 139px); - max-height: calc(100vh - 139px); + height: calc(100vh - 174px); + max-height: calc(100vh - 174px); +} + +/* Additional CSS for the paginated image gallery */ + +/* Page info styling */ +#image-page-info { + display: flex; + align-items: center; + justify-content: center; + min-width: 200px; + font-size: 0.9em; + color: var(--body-text-color-subdued); +} + +/* Settings display panel */ +#image-ai-tab .settings-display-panel { + background: var(--background-fill-secondary); + padding: 12px; + border-radius: 8px; + font-size: 0.9em; + max-height: 300px; + overflow-y: auto; + margin-top: 8px; +} + +/* Gallery status message */ +#image-ai-tab .gallery-status { + color: var(--color-accent); + font-size: 0.85em; + margin-top: 4px; +} + +/* Pagination button row alignment */ +#image-ai-tab .pagination-controls { + display: flex; + align-items: center; + gap: 8px; + flex-wrap: wrap; +} + +/* Selected image preview container */ +#image-ai-tab .selected-preview-container { + border: 1px solid var(--border-color-primary); + border-radius: 8px; + padding: 8px; + background: var(--background-fill-secondary); } diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py index d9e79973..b202f6cc 100644 --- a/modules/ui_image_generation.py +++ b/modules/ui_image_generation.py @@ -1,3 +1,4 @@ +import json import os import time import traceback @@ -7,6 +8,8 @@ from pathlib import Path import gradio as gr import numpy as np import torch +from PIL import Image +from PIL.PngImagePlugin import PngInfo from modules import shared, ui, utils from modules.image_models import load_image_model, unload_image_model @@ -22,6 +25,24 @@ ASPECT_RATIOS = { } STEP = 32 +IMAGES_PER_PAGE = 64 + +# Settings keys to save in PNG metadata (Generate tab only) +METADATA_SETTINGS_KEYS = [ + 'image_prompt', + 'image_neg_prompt', + 'image_width', + 'image_height', + 'image_aspect_ratio', + 'image_steps', + 'image_seed', + 'image_batch_size', + 'image_batch_count', +] + +# Cache for all image paths +_image_cache = [] +_cache_timestamp = 0 def round_to_step(value, step=STEP): @@ -93,6 +114,215 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio): return new_width, new_height, new_ratio +def build_generation_metadata(state, actual_seed): + """Build metadata dict from generation settings.""" + metadata = {} + for key in METADATA_SETTINGS_KEYS: + if key in state: + metadata[key] = state[key] + + # Store the actual seed used (not -1) + metadata['image_seed'] = actual_seed + metadata['generated_at'] = datetime.now().isoformat() + metadata['model'] = shared.image_model_name + + return metadata + + +def save_generated_images(images, state, actual_seed): + """Save images with generation metadata embedded in PNG.""" + date_str = datetime.now().strftime("%Y-%m-%d") + folder_path = os.path.join("user_data", "image_outputs", date_str) + os.makedirs(folder_path, exist_ok=True) + + metadata = build_generation_metadata(state, actual_seed) + metadata_json = json.dumps(metadata, ensure_ascii=False) + + for idx, img in enumerate(images): + timestamp = datetime.now().strftime("%H-%M-%S") + filename = f"{timestamp}_{actual_seed}_{idx}.png" + filepath = os.path.join(folder_path, filename) + + # Create PNG metadata + png_info = PngInfo() + png_info.add_text("image_gen_settings", metadata_json) + + # Save with metadata + img.save(filepath, pnginfo=png_info) + + +def read_image_metadata(image_path): + """Read generation metadata from PNG file.""" + try: + with Image.open(image_path) as img: + if hasattr(img, 'text') and 'image_gen_settings' in img.text: + return json.loads(img.text['image_gen_settings']) + except Exception as e: + logger.debug(f"Could not read metadata from {image_path}: {e}") + return None + + +def format_metadata_for_display(metadata): + """Format metadata as readable text.""" + if not metadata: + return "No generation settings found in this image." + + lines = ["**Generation Settings**", ""] + + # Display in a nice order + display_order = [ + ('image_prompt', 'Prompt'), + ('image_neg_prompt', 'Negative Prompt'), + ('image_width', 'Width'), + ('image_height', 'Height'), + ('image_aspect_ratio', 'Aspect Ratio'), + ('image_steps', 'Steps'), + ('image_seed', 'Seed'), + ('image_batch_size', 'Batch Size'), + ('image_batch_count', 'Batch Count'), + ('model', 'Model'), + ('generated_at', 'Generated At'), + ] + + for key, label in display_order: + if key in metadata: + value = metadata[key] + if key in ['image_prompt', 'image_neg_prompt'] and value: + # Truncate long prompts for display + if len(str(value)) > 200: + value = str(value)[:200] + "..." + lines.append(f"**{label}:** {value}") + + return "\n\n".join(lines) + + +def get_all_history_images(force_refresh=False): + """Get all history images sorted by modification time (newest first). Uses caching.""" + global _image_cache, _cache_timestamp + + output_dir = os.path.join("user_data", "image_outputs") + if not os.path.exists(output_dir): + return [] + + # Check if we need to refresh cache + current_time = time.time() + if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2: + return _image_cache + + image_files = [] + for root, _, files in os.walk(output_dir): + for file in files: + if file.endswith((".png", ".jpg", ".jpeg")): + full_path = os.path.join(root, file) + image_files.append((full_path, os.path.getmtime(full_path))) + + image_files.sort(key=lambda x: x[1], reverse=True) + _image_cache = [x[0] for x in image_files] + _cache_timestamp = current_time + + return _image_cache + + +def get_paginated_images(page=0, force_refresh=False): + """Get images for a specific page.""" + all_images = get_all_history_images(force_refresh) + total_images = len(all_images) + total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE) + + # Clamp page to valid range + page = max(0, min(page, total_pages - 1)) + + start_idx = page * IMAGES_PER_PAGE + end_idx = min(start_idx + IMAGES_PER_PAGE, total_images) + + page_images = all_images[start_idx:end_idx] + + return page_images, page, total_pages, total_images + + +def refresh_gallery(current_page=0): + """Refresh gallery with current page.""" + images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True) + page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" + return images, page, page_info + + +def go_to_page(page_num, current_page): + """Go to a specific page (1-indexed input).""" + try: + page = int(page_num) - 1 # Convert to 0-indexed + except (ValueError, TypeError): + page = current_page + + images, page, total_pages, total_images = get_paginated_images(page) + page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" + return images, page, page_info + + +def next_page(current_page): + """Go to next page.""" + images, page, total_pages, total_images = get_paginated_images(current_page + 1) + page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" + return images, page, page_info + + +def prev_page(current_page): + """Go to previous page.""" + images, page, total_pages, total_images = get_paginated_images(current_page - 1) + page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" + return images, page, page_info + + +def on_gallery_select(evt: gr.SelectData, current_page): + """Handle image selection from gallery.""" + if evt.index is None: + return "", "Select an image to view its settings" + + # Get the current page's images to find the actual file path + all_images = get_all_history_images() + total_images = len(all_images) + + # Calculate the actual index in the full list + start_idx = current_page * IMAGES_PER_PAGE + actual_idx = start_idx + evt.index + + if actual_idx >= total_images: + return "", "Image not found" + + image_path = all_images[actual_idx] + metadata = read_image_metadata(image_path) + metadata_display = format_metadata_for_display(metadata) + + return image_path, metadata_display + + +def send_to_generate(selected_image_path): + """Load settings from selected image and return updates for all Generate tab inputs.""" + if not selected_image_path or not os.path.exists(selected_image_path): + return [gr.update()] * 9 + ["No image selected"] + + metadata = read_image_metadata(selected_image_path) + if not metadata: + return [gr.update()] * 9 + ["No settings found in this image"] + + # Return updates for each input element in order + updates = [ + gr.update(value=metadata.get('image_prompt', '')), + gr.update(value=metadata.get('image_neg_prompt', '')), + gr.update(value=metadata.get('image_width', 1024)), + gr.update(value=metadata.get('image_height', 1024)), + gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')), + gr.update(value=metadata.get('image_steps', 9)), + gr.update(value=metadata.get('image_seed', -1)), + gr.update(value=metadata.get('image_batch_size', 1)), + gr.update(value=metadata.get('image_batch_count', 1)), + ] + + status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})" + return updates + [status] + + + def create_ui(): if shared.settings['image_model_menu'] != 'None': shared.image_model_name = shared.settings['image_model_menu'] @@ -149,11 +379,40 @@ def create_ui(): with gr.Column(elem_classes=["viewport-container"]): shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery") - # TAB 2: GALLERY + # TAB 2: GALLERY (with pagination) with gr.TabItem("Gallery"): with gr.Row(): - shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button") - shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery") + with gr.Column(scale=3): + # Pagination controls + with gr.Row(): + shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button") + shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button") + shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info") + shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button") + shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80) + shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50) + + # State for current page and selected image path + shared.gradio['image_current_page'] = gr.State(value=0) + shared.gradio['image_selected_path'] = gr.State(value="") + + # Paginated gallery using gr.Gallery + shared.gradio['image_history_gallery'] = gr.Gallery( + value=lambda: get_paginated_images(0)[0], + label="Image History", + show_label=False, + columns=6, + object_fit="cover", + height="auto", + allow_preview=True, + elem_id="image-history-gallery" + ) + + with gr.Column(scale=1): + gr.Markdown("### Selected Image") + shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings") + shared.gradio['image_send_to_generate'] = gr.Button("📤 Send to Generate", variant="primary") + shared.gradio['image_gallery_status'] = gr.Markdown("") # TAB 3: MODEL with gr.TabItem("Model"): @@ -281,11 +540,59 @@ def create_event_handlers(): show_progress=True ) - # History + # Gallery pagination handlers shared.gradio['image_refresh_history'].click( - get_history_images, - None, - gradio('image_history_gallery'), + refresh_gallery, + gradio('image_current_page'), + gradio('image_history_gallery', 'image_current_page', 'image_page_info'), + show_progress=False + ) + + shared.gradio['image_next_page'].click( + next_page, + gradio('image_current_page'), + gradio('image_history_gallery', 'image_current_page', 'image_page_info'), + show_progress=False + ) + + shared.gradio['image_prev_page'].click( + prev_page, + gradio('image_current_page'), + gradio('image_history_gallery', 'image_current_page', 'image_page_info'), + show_progress=False + ) + + shared.gradio['image_go_to_page'].click( + go_to_page, + gradio('image_page_input', 'image_current_page'), + gradio('image_history_gallery', 'image_current_page', 'image_page_info'), + show_progress=False + ) + + # Image selection from gallery + shared.gradio['image_history_gallery'].select( + on_gallery_select, + gradio('image_current_page'), + gradio('image_selected_path', 'image_settings_display'), + show_progress=False + ) + + # Send to Generate + shared.gradio['image_send_to_generate'].click( + send_to_generate, + gradio('image_selected_path'), + gradio( + 'image_prompt', + 'image_neg_prompt', + 'image_width', + 'image_height', + 'image_aspect_ratio', + 'image_steps', + 'image_seed', + 'image_batch_size', + 'image_batch_count', + 'image_gallery_status' + ), show_progress=False ) @@ -334,7 +641,7 @@ def generate(state): all_images.extend(batch_results) t1 = time.time() - save_generated_images(all_images, state['image_prompt'], seed) + save_generated_images(all_images, state, seed) logger.info(f'Images generated in {(t1-t0):.2f} seconds (seed {seed})') return all_images @@ -402,30 +709,3 @@ def download_image_model_wrapper(model_path): yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name) except Exception: yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update() - - -def save_generated_images(images, prompt, seed): - date_str = datetime.now().strftime("%Y-%m-%d") - folder_path = os.path.join("user_data", "image_outputs", date_str) - os.makedirs(folder_path, exist_ok=True) - - for idx, img in enumerate(images): - timestamp = datetime.now().strftime("%H-%M-%S") - filename = f"{timestamp}_{seed}_{idx}.png" - img.save(os.path.join(folder_path, filename)) - - -def get_history_images(): - output_dir = os.path.join("user_data", "image_outputs") - if not os.path.exists(output_dir): - return [] - - image_files = [] - for root, _, files in os.walk(output_dir): - for file in files: - if file.endswith((".png", ".jpg", ".jpeg")): - full_path = os.path.join(root, file) - image_files.append((full_path, os.path.getmtime(full_path))) - - image_files.sort(key=lambda x: x[1], reverse=True) - return [x[0] for x in image_files]