Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-12-06 07:12:10 +01:00)

Compare commits: 16 commits, 366fe353f0...f45412676d
Commits (SHA1):

f45412676d
5c61fcf479
d75d7a3a63
151b552bc3
322aab3410
f46f49e26c
225b8c326b
5fb1380ac1
7dfb6e9c57
a7808f7f42
748e2e55fd
6a7209a842
c8e9d7fc37
b4738beaf8
75796f5a58
990f0e2468
css/main.css (56 changed lines)

@@ -1681,3 +1681,59 @@ button#swap-height-width {
     right: 0;
     border: 0;
 }
+
+#image-output-gallery, #image-output-gallery > :nth-child(2) {
+    height: calc(100vh - 83px);
+    max-height: calc(100vh - 83px);
+}
+
+#image-history-gallery, #image-history-gallery > :nth-child(2) {
+    height: calc(100vh - 174px);
+    max-height: calc(100vh - 174px);
+}
+
+/* Additional CSS for the paginated image gallery */
+
+/* Page info styling */
+#image-page-info {
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    min-width: 200px;
+    font-size: 0.9em;
+    color: var(--body-text-color-subdued);
+}
+
+/* Settings display panel */
+#image-ai-tab .settings-display-panel {
+    background: var(--background-fill-secondary);
+    padding: 12px;
+    border-radius: 8px;
+    font-size: 0.9em;
+    max-height: 300px;
+    overflow-y: auto;
+    margin-top: 8px;
+}
+
+/* Gallery status message */
+#image-ai-tab .gallery-status {
+    color: var(--color-accent);
+    font-size: 0.85em;
+    margin-top: 4px;
+}
+
+/* Pagination button row alignment */
+#image-ai-tab .pagination-controls {
+    display: flex;
+    align-items: center;
+    gap: 8px;
+    flex-wrap: wrap;
+}
+
+/* Selected image preview container */
+#image-ai-tab .selected-preview-container {
+    border: 1px solid var(--border-color-primary);
+    border-radius: 8px;
+    padding: 8px;
+    background: var(--background-fill-secondary);
+}
modules/image_models.py

@@ -1,14 +1,97 @@
 import time

 import torch

 import modules.shared as shared
 from modules.logging_colors import logger
 from modules.torch_utils import get_device
 from modules.utils import resolve_model_path


-def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False):
+def get_quantization_config(quant_method):
+    """
+    Get the appropriate quantization config based on the selected method.
+
+    Args:
+        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+
+    Returns:
+        PipelineQuantizationConfig or None
+    """
+    import torch
+    from diffusers import BitsAndBytesConfig, QuantoConfig
+    from diffusers.quantizers import PipelineQuantizationConfig
+
+    if quant_method == 'none' or not quant_method:
+        return None
+
+    # Bitsandbytes 8-bit quantization
+    elif quant_method == 'bnb-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_8bit=True
+                )
+            }
+        )
+
+    # Bitsandbytes 4-bit quantization
+    elif quant_method == 'bnb-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True
+                )
+            }
+        )
+
+    # Quanto 8-bit quantization
+    elif quant_method == 'quanto-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8")
+            }
+        )
+
+    # Quanto 4-bit quantization
+    elif quant_method == 'quanto-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int4")
+            }
+        )
+
+    # Quanto 2-bit quantization
+    elif quant_method == 'quanto-2bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int2")
+            }
+        )
+
+    else:
+        logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
+        return None


+def get_pipeline_type(pipe):
+    """
+    Detect the pipeline type based on the loaded pipeline class.
+
+    Returns:
+        str: 'zimage', 'qwenimage', or 'unknown'
+    """
+    class_name = pipe.__class__.__name__
+    if 'ZImage' in class_name:
+        return 'zimage'
+    elif 'QwenImage' in class_name:
+        return 'qwenimage'
+    else:
+        return 'unknown'
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
     """
     Load a diffusers image generation model.
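Side note: a minimal standalone sketch of how one of these per-component configs is consumed by diffusers (assuming a diffusers version that ships PipelineQuantizationConfig plus a bitsandbytes install; the checkpoint path is hypothetical):

```python
# Sketch: quantize only the transformer of a diffusers pipeline, keep the rest in bf16.
import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_mapping={
        # Same shape as get_quantization_config('bnb-4bit') above.
        "transformer": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    }
)

pipe = DiffusionPipeline.from_pretrained(
    "models/z-image-turbo",  # hypothetical local checkpoint path
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```

Components not named in quant_mapping (text encoder, VAE) are loaded unquantized, which is why only "transformer" appears above.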
@@ -18,10 +101,12 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
         attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
         cpu_offload: Enable CPU offloading for low VRAM
         compile_model: Compile the model for faster inference (slow first run)
+        quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
     """
-    from diffusers import PipelineQuantizationConfig, ZImagePipeline
+    import torch
+    from diffusers import DiffusionPipeline

-    logger.info(f"Loading image model \"{model_name}\"")
+    logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
     t0 = time.time()

     dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
@@ -30,49 +115,49 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
     model_path = resolve_model_path(model_name, image_model=True)

     try:
-        # Define quantization config for 8-bit
-        pipeline_quant_config = PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
-            quant_kwargs={"load_in_8bit": True},
-        )
+        # Get quantization config based on selected method
+        pipeline_quant_config = get_quantization_config(quant_method)

-        # Define quantization config for 4-bit
-        # pipeline_quant_config = PipelineQuantizationConfig(
-        #     quant_backend="bitsandbytes_4bit",
-        #     quant_kwargs={
-        #         "load_in_4bit": True,
-        #         "bnb_4bit_quant_type": "nf4",  # Or "fp4" for floating point
-        #         "bnb_4bit_compute_dtype": torch.bfloat16,  # For faster computation
-        #         "bnb_4bit_use_double_quant": True,  # Nested quantization for extra savings
-        #     },
-        # )
+        # Load the pipeline
+        load_kwargs = {
+            "torch_dtype": target_dtype,
+            "low_cpu_mem_usage": True,
+        }

-        pipe = ZImagePipeline.from_pretrained(
+        if pipeline_quant_config is not None:
+            load_kwargs["quantization_config"] = pipeline_quant_config
+
+        # Use DiffusionPipeline for automatic pipeline detection
+        # This handles both ZImagePipeline and QwenImagePipeline
+        pipe = DiffusionPipeline.from_pretrained(
             str(model_path),
-            quantization_config=pipeline_quant_config,
-            torch_dtype=target_dtype,
-            low_cpu_mem_usage=True,
+            **load_kwargs
         )

+        pipeline_type = get_pipeline_type(pipe)
+
         if not cpu_offload:
             pipe.to(get_device())

-        # Set attention backend
-        if attn_backend == 'flash_attention_2':
-            pipe.transformer.set_attention_backend("flash")
-        elif attn_backend == 'flash_attention_3':
-            pipe.transformer.set_attention_backend("_flash_3")
-        # sdpa is the default, no action needed
+        # Set attention backend (if supported by the pipeline)
+        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+            if attn_backend == 'flash_attention_2':
+                pipe.transformer.set_attention_backend("flash")
+            elif attn_backend == 'flash_attention_3':
+                pipe.transformer.set_attention_backend("_flash_3")
+            # sdpa is the default, no action needed

         if compile_model:
-            logger.info("Compiling model (first run will be slow)...")
-            pipe.transformer.compile()
+            if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+                logger.info("Compiling model (first run will be slow)...")
+                pipe.transformer.compile()

         if cpu_offload:
             pipe.enable_model_cpu_offload()

         shared.image_model = pipe
         shared.image_model_name = model_name
+        shared.image_pipeline_type = pipeline_type

         logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
         return pipe
@@ -90,6 +175,7 @@ def unload_image_model():
     del shared.image_model
     shared.image_model = None
     shared.image_model_name = 'None'
+    shared.image_pipeline_type = None

    from modules.torch_utils import clear_torch_cache
    clear_torch_cache()

modules/shared.py
@@ -58,6 +58,9 @@ group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16',
 group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
 group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
 group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+group.add_argument('--image-quant', type=str, default=None,
+                   choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                   help='Quantization method for image model.')

 # Model loader
 group = parser.add_argument_group('Model loader')
@@ -311,14 +314,16 @@ settings = {
     'image_height': 1024,
     'image_aspect_ratio': '1:1 Square',
     'image_steps': 9,
+    'image_cfg_scale': 0.0,
     'image_seed': -1,
     'image_batch_size': 1,
     'image_batch_count': 1,
     'image_model_menu': 'None',
     'image_dtype': 'bfloat16',
     'image_attn_backend': 'sdpa',
-    'image_compile': False,
     'image_cpu_offload': False,
+    'image_compile': False,
+    'image_quant': 'none',
 }

 default_settings = copy.deepcopy(settings)
@@ -344,8 +349,8 @@ def do_cmd_flags_warnings():


 def apply_image_model_cli_overrides():
-    """Apply CLI flags for image model settings, overriding saved settings."""
-    if args.image_model:
+    """Apply command-line overrides for image model settings."""
+    if args.image_model is not None:
         settings['image_model_menu'] = args.image_model
     if args.image_dtype is not None:
         settings['image_dtype'] = args.image_dtype
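The switch from a truthiness test to an explicit `is not None` check presumably matters because argparse defaults these flags to None, so only an unset flag should leave the saved setting alone. A toy check (all values hypothetical):

```python
# Toy check: distinguish "flag not passed" (None) from "passed but falsy".
saved = {'image_model_menu': 'z-image-turbo'}  # hypothetical saved setting

for cli_value in (None, '', 'qwen-image'):
    settings = dict(saved)
    if cli_value is not None:  # an explicitly passed empty value still overrides
        settings['image_model_menu'] = cli_value
    print(repr(cli_value), '->', repr(settings['image_model_menu']))
# None -> 'z-image-turbo', '' -> '', 'qwen-image' -> 'qwen-image'
```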
@@ -355,6 +360,9 @@ def apply_image_model_cli_overrides():
         settings['image_cpu_offload'] = True
     if args.image_compile:
         settings['image_compile'] = True
+    if args.image_quant is not None:
+        settings['image_quant'] = args.image_quant


 def fix_loader_name(name):

modules/ui.py
@@ -288,6 +288,7 @@ def list_interface_input_elements():
         'image_height',
         'image_aspect_ratio',
         'image_steps',
+        'image_cfg_scale',
         'image_seed',
         'image_batch_size',
         'image_batch_count',

@@ -296,6 +297,7 @@ def list_interface_input_elements():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]

     return elements

@@ -530,10 +532,13 @@ def setup_auto_save():
         'include_past_attachments',

+        # Image generation tab (ui_image_generation.py)
         'image_prompt',
         'image_neg_prompt',
         'image_width',
         'image_height',
+        'image_aspect_ratio',
         'image_steps',
+        'image_cfg_scale',
         'image_seed',
         'image_batch_size',
         'image_batch_count',

@@ -542,6 +547,7 @@ def setup_auto_save():
         'image_attn_backend',
         'image_compile',
         'image_cpu_offload',
+        'image_quant',
     ]

     for element_name in change_elements:

modules/ui_image_generation.py
@@ -1,14 +1,18 @@
+import json
 import os
 import time
 import traceback
 from datetime import datetime
 from pathlib import Path

 import gradio as gr
+import numpy as np
+import torch
 from PIL import Image
+from PIL.PngImagePlugin import PngInfo

 from modules import shared, ui, utils
 from modules.image_models import load_image_model, unload_image_model
 from modules.logging_colors import logger
 from modules.utils import gradio

 ASPECT_RATIOS = {
@@ -19,7 +23,26 @@ ASPECT_RATIOS = {
     "Custom": None,
 }

-STEP = 32
+STEP = 16
+IMAGES_PER_PAGE = 64
+
+# Settings keys to save in PNG metadata (Generate tab only)
+METADATA_SETTINGS_KEYS = [
+    'image_prompt',
+    'image_neg_prompt',
+    'image_width',
+    'image_height',
+    'image_aspect_ratio',
+    'image_steps',
+    'image_seed',
+    'image_batch_size',
+    'image_batch_count',
+    'image_cfg_scale',
+]
+
+# Cache for all image paths
+_image_cache = []
+_cache_timestamp = 0


 def round_to_step(value, step=STEP):
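round_to_step itself is outside this diff; presumably it snaps slider values to multiples of STEP, along these lines (hypothetical implementation, shown only for context):

```python
STEP = 16

def round_to_step(value, step=STEP):
    # Hypothetical body: snap to the nearest multiple of `step`.
    return int(round(value / step)) * step

print(round_to_step(1023))  # 1024
print(round_to_step(770))   # 768
```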
@@ -91,6 +114,216 @@ def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
     return new_width, new_height, new_ratio


+def build_generation_metadata(state, actual_seed):
+    """Build metadata dict from generation settings."""
+    metadata = {}
+    for key in METADATA_SETTINGS_KEYS:
+        if key in state:
+            metadata[key] = state[key]
+
+    # Store the actual seed used (not -1)
+    metadata['image_seed'] = actual_seed
+    metadata['generated_at'] = datetime.now().isoformat()
+    metadata['model'] = shared.image_model_name
+
+    return metadata
+
+
+def save_generated_images(images, state, actual_seed):
+    """Save images with generation metadata embedded in PNG."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder_path = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder_path, exist_ok=True)
+
+    metadata = build_generation_metadata(state, actual_seed)
+    metadata_json = json.dumps(metadata, ensure_ascii=False)
+
+    for idx, img in enumerate(images):
+        timestamp = datetime.now().strftime("%H-%M-%S")
+        filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
+        filepath = os.path.join(folder_path, filename)
+
+        # Create PNG metadata
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", metadata_json)
+
+        # Save with metadata
+        img.save(filepath, pnginfo=png_info)
+
+
+def read_image_metadata(image_path):
+    """Read generation metadata from PNG file."""
+    try:
+        with Image.open(image_path) as img:
+            if hasattr(img, 'text') and 'image_gen_settings' in img.text:
+                return json.loads(img.text['image_gen_settings'])
+    except Exception as e:
+        logger.debug(f"Could not read metadata from {image_path}: {e}")
+    return None
+
+
+def format_metadata_for_display(metadata):
+    """Format metadata as readable text."""
+    if not metadata:
+        return "No generation settings found in this image."
+
+    lines = ["**Generation Settings**", ""]
+
+    # Display in a nice order
+    display_order = [
+        ('image_prompt', 'Prompt'),
+        ('image_neg_prompt', 'Negative Prompt'),
+        ('image_width', 'Width'),
+        ('image_height', 'Height'),
+        ('image_aspect_ratio', 'Aspect Ratio'),
+        ('image_steps', 'Steps'),
+        ('image_cfg_scale', 'CFG Scale'),
+        ('image_seed', 'Seed'),
+        ('image_batch_size', 'Batch Size'),
+        ('image_batch_count', 'Batch Count'),
+        ('model', 'Model'),
+        ('generated_at', 'Generated At'),
+    ]
+
+    for key, label in display_order:
+        if key in metadata:
+            value = metadata[key]
+            if key in ['image_prompt', 'image_neg_prompt'] and value:
+                # Truncate long prompts for display
+                if len(str(value)) > 200:
+                    value = str(value)[:200] + "..."
+            lines.append(f"**{label}:** {value}")
+
+    return "\n\n".join(lines)
+
+
+def get_all_history_images(force_refresh=False):
+    """Get all history images sorted by modification time (newest first). Uses caching."""
+    global _image_cache, _cache_timestamp
+
+    output_dir = os.path.join("user_data", "image_outputs")
+    if not os.path.exists(output_dir):
+        return []
+
+    # Check if we need to refresh cache
+    current_time = time.time()
+    if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
+        return _image_cache
+
+    image_files = []
+    for root, _, files in os.walk(output_dir):
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg")):
+                full_path = os.path.join(root, file)
+                image_files.append((full_path, os.path.getmtime(full_path)))
+
+    image_files.sort(key=lambda x: x[1], reverse=True)
+    _image_cache = [x[0] for x in image_files]
+    _cache_timestamp = current_time
+
+    return _image_cache
+
+
+def get_paginated_images(page=0, force_refresh=False):
+    """Get images for a specific page."""
+    all_images = get_all_history_images(force_refresh)
+    total_images = len(all_images)
+    total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
+
+    # Clamp page to valid range
+    page = max(0, min(page, total_pages - 1))
+
+    start_idx = page * IMAGES_PER_PAGE
+    end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
+
+    page_images = all_images[start_idx:end_idx]
+
+    return page_images, page, total_pages, total_images
+
+
+def refresh_gallery(current_page=0):
+    """Refresh gallery with current page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def go_to_page(page_num, current_page):
+    """Go to a specific page (1-indexed input)."""
+    try:
+        page = int(page_num) - 1  # Convert to 0-indexed
+    except (ValueError, TypeError):
+        page = current_page
+
+    images, page, total_pages, total_images = get_paginated_images(page)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def next_page(current_page):
+    """Go to next page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page + 1)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def prev_page(current_page):
+    """Go to previous page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page - 1)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def on_gallery_select(evt: gr.SelectData, current_page):
+    """Handle image selection from gallery."""
+    if evt.index is None:
+        return "", "Select an image to view its settings"
+
+    # Get the current page's images to find the actual file path
+    all_images = get_all_history_images()
+    total_images = len(all_images)
+
+    # Calculate the actual index in the full list
+    start_idx = current_page * IMAGES_PER_PAGE
+    actual_idx = start_idx + evt.index
+
+    if actual_idx >= total_images:
+        return "", "Image not found"
+
+    image_path = all_images[actual_idx]
+    metadata = read_image_metadata(image_path)
+    metadata_display = format_metadata_for_display(metadata)
+
+    return image_path, metadata_display
+
+
+def send_to_generate(selected_image_path):
+    """Load settings from selected image and return updates for all Generate tab inputs."""
+    if not selected_image_path or not os.path.exists(selected_image_path):
+        return [gr.update()] * 10 + ["No image selected"]
+
+    metadata = read_image_metadata(selected_image_path)
+    if not metadata:
+        return [gr.update()] * 10 + ["No settings found in this image"]
+
+    # Return updates for each input element in order
+    updates = [
+        gr.update(value=metadata.get('image_prompt', '')),
+        gr.update(value=metadata.get('image_neg_prompt', '')),
+        gr.update(value=metadata.get('image_width', 1024)),
+        gr.update(value=metadata.get('image_height', 1024)),
+        gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
+        gr.update(value=metadata.get('image_steps', 9)),
+        gr.update(value=metadata.get('image_seed', -1)),
+        gr.update(value=metadata.get('image_batch_size', 1)),
+        gr.update(value=metadata.get('image_batch_count', 1)),
+        gr.update(value=metadata.get('image_cfg_scale', 0.0)),
+    ]
+
+    status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
+    return updates + [status]
+
+
 def create_ui():
     if shared.settings['image_model_menu'] != 'None':
         shared.image_model_name = shared.settings['image_model_menu']
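The metadata embedding above is Pillow's standard PNG text-chunk API; the full round-trip fits in a few lines. A standalone sketch (file name and settings dict are illustrative):

```python
# Standalone sketch of the PNG text-chunk round-trip used by
# save_generated_images() / read_image_metadata().
import json

from PIL import Image
from PIL.PngImagePlugin import PngInfo

settings = {"image_prompt": "a red fox", "image_seed": 42}

png_info = PngInfo()
png_info.add_text("image_gen_settings", json.dumps(settings, ensure_ascii=False))
Image.new("RGB", (64, 64)).save("example.png", pnginfo=png_info)

with Image.open("example.png") as img:
    print(json.loads(img.text["image_gen_settings"]))
# {'image_prompt': 'a red fox', 'image_seed': 42}
```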
@@ -115,15 +348,15 @@ def create_ui():
                     value=shared.settings['image_neg_prompt']
                 )

-                shared.gradio['image_generate_btn'] = gr.Button("✨ GENERATE", variant="primary", size="lg", elem_id="gen-btn")
+                shared.gradio['image_generate_btn'] = gr.Button("GENERATE", variant="primary", size="lg", elem_id="gen-btn")
                 gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")

                 gr.Markdown("### Dimensions")
                 with gr.Row():
                     with gr.Column():
-                        shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=32, label="Width")
+                        shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")
                     with gr.Column():
-                        shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=32, label="Height")
+                        shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")
                 shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")

                 with gr.Row():
@@ -137,7 +370,14 @@ def create_ui():
                 gr.Markdown("### Config")
                 with gr.Row():
                     with gr.Column():
-                        shared.gradio['image_steps'] = gr.Slider(1, 15, value=shared.settings['image_steps'], step=1, label="Steps")
+                        shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
+                        shared.gradio['image_cfg_scale'] = gr.Slider(
+                            0.0, 10.0,
+                            value=0.0,
+                            step=0.1,
+                            label="CFG Scale",
+                            info="Z-Image Turbo: 0.0 | Qwen: 4.0"
+                        )
                         shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
                     with gr.Column():
                         shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
@@ -146,14 +386,41 @@ def create_ui():
             with gr.Column(scale=6, min_width=500):
                 with gr.Column(elem_classes=["viewport-container"]):
                     shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
-                    with gr.Row():
-                        shared.gradio['image_used_seed'] = gr.Markdown(label="Info", interactive=False)

-        # TAB 2: GALLERY
+        # TAB 2: GALLERY (with pagination)
         with gr.TabItem("Gallery"):
             with gr.Row():
-                shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh Gallery", elem_classes="refresh-button")
-            shared.gradio['image_history_gallery'] = gr.Gallery(value=lambda : get_history_images(), label="History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery")
+                with gr.Column(scale=3):
+                    # Pagination controls
+                    with gr.Row():
+                        shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
+                        shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
+                        shared.gradio['image_page_info'] = gr.Markdown("Loading...", elem_id="image-page-info")
+                        shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
+                        shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
+                        shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
+
+                    # State for current page and selected image path
+                    shared.gradio['image_current_page'] = gr.State(value=0)
+                    shared.gradio['image_selected_path'] = gr.State(value="")
+
+                    # Paginated gallery using gr.Gallery
+                    shared.gradio['image_history_gallery'] = gr.Gallery(
+                        value=lambda: get_paginated_images(0)[0],
+                        label="Image History",
+                        show_label=False,
+                        columns=6,
+                        object_fit="cover",
+                        height="auto",
+                        allow_preview=True,
+                        elem_id="image-history-gallery"
+                    )
+
+                with gr.Column(scale=1):
+                    gr.Markdown("### Selected Image")
+                    shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
+                    shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary")
+                    shared.gradio['image_gallery_status'] = gr.Markdown("")

         # TAB 3: MODEL
         with gr.TabItem("Model"):
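The paging behind these controls is ceiling division plus clamping (see get_paginated_images above); a quick worked check:

```python
# Worked check of the pagination arithmetic in get_paginated_images().
IMAGES_PER_PAGE = 64
requested_page = 5

for total_images in (0, 1, 64, 65, 200):
    total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
    page = max(0, min(requested_page, total_pages - 1))  # clamp to valid range
    start = page * IMAGES_PER_PAGE
    end = min(start + IMAGES_PER_PAGE, total_images)
    print(f"{total_images:3d} images -> {total_pages} page(s); requested {requested_page}, got page {page} [{start}:{end}]")
```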
@@ -173,6 +440,13 @@ def create_ui():
             gr.Markdown("## Settings")
             with gr.Row():
                 with gr.Column():
+                    shared.gradio['image_quant'] = gr.Dropdown(
+                        label='Quantization',
+                        choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                        value=shared.settings['image_quant'],
+                        info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
+                    )
+
                     shared.gradio['image_dtype'] = gr.Dropdown(
                         choices=['bfloat16', 'float16'],
                         value=shared.settings['image_dtype'],
@@ -242,15 +516,15 @@ def create_event_handlers():
     # Generation
     shared.gradio['image_generate_btn'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+        generate, gradio('interface_state'), gradio('image_output_gallery'))

     shared.gradio['image_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+        generate, gradio('interface_state'), gradio('image_output_gallery'))

     shared.gradio['image_neg_prompt'].submit(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
-        generate, gradio('interface_state'), gradio('image_output_gallery', 'image_used_seed'))
+        generate, gradio('interface_state'), gradio('image_output_gallery'))

     # Model management
     shared.gradio['image_refresh_models'].click(
@@ -262,7 +536,7 @@ def create_event_handlers():

     shared.gradio['image_load_model'].click(
         load_image_model_wrapper,
-        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile'),
+        gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
         gradio('image_model_status'),
         show_progress=True
     )
@@ -281,20 +555,77 @@ def create_event_handlers():
         show_progress=True
     )

-    # History
+    # Gallery pagination handlers
     shared.gradio['image_refresh_history'].click(
-        get_history_images,
-        None,
-        gradio('image_history_gallery'),
+        refresh_gallery,
+        gradio('image_current_page'),
+        gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
         show_progress=False
     )
+
+    shared.gradio['image_next_page'].click(
+        next_page,
+        gradio('image_current_page'),
+        gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+        show_progress=False
+    )
+
+    shared.gradio['image_prev_page'].click(
+        prev_page,
+        gradio('image_current_page'),
+        gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+        show_progress=False
+    )
+
+    shared.gradio['image_go_to_page'].click(
+        go_to_page,
+        gradio('image_page_input', 'image_current_page'),
+        gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
+        show_progress=False
+    )
+
+    # Image selection from gallery
+    shared.gradio['image_history_gallery'].select(
+        on_gallery_select,
+        gradio('image_current_page'),
+        gradio('image_selected_path', 'image_settings_display'),
+        show_progress=False
+    )
+
+    # Send to Generate
+    shared.gradio['image_send_to_generate'].click(
+        send_to_generate,
+        gradio('image_selected_path'),
+        gradio(
+            'image_prompt',
+            'image_neg_prompt',
+            'image_width',
+            'image_height',
+            'image_aspect_ratio',
+            'image_steps',
+            'image_seed',
+            'image_batch_size',
+            'image_batch_count',
+            'image_cfg_scale',
+            'image_gallery_status'
+        ),
+        show_progress=False
+    )


 def generate(state):
-    import torch
-    import numpy as np
-
+    """
+    Generate images using the loaded model.
+    Automatically adjusts parameters based on pipeline type.
+    """
     model_name = state['image_model_menu']

     if not model_name or model_name == 'None':
-        return [], "No image model selected. Go to the Model tab and select a model."
+        logger.error("No image model selected. Go to the Model tab and select a model.")
+        return []

     if shared.image_model is None:
         result = load_image_model(
@@ -302,10 +633,12 @@ def generate(state):
             dtype=state['image_dtype'],
             attn_backend=state['image_attn_backend'],
             cpu_offload=state['image_cpu_offload'],
-            compile_model=state['image_compile']
+            compile_model=state['image_compile'],
+            quant_method=state['image_quant']
         )
         if result is None:
-            return [], f"Failed to load model `{model_name}`."
+            logger.error(f"Failed to load model `{model_name}`.")
+            return []

         shared.image_model_name = model_name

@@ -316,25 +649,56 @@ def generate(state):
     generator = torch.Generator("cuda").manual_seed(int(seed))
     all_images = []

+    # Get pipeline type for parameter adjustment
+    pipeline_type = getattr(shared, 'image_pipeline_type', None)
+    if pipeline_type is None:
+        pipeline_type = get_pipeline_type(shared.image_model)
+
+    # Process Prompt
+    prompt = state['image_prompt']
+
+    # Apply "Positive Magic" for Qwen models only
+    if pipeline_type == 'qwenimage':
+        magic_suffix = ", Ultra HD, 4K, cinematic composition"
+        # Avoid duplication if user already added it
+        if magic_suffix.strip(", ") not in prompt:
+            prompt += magic_suffix
+
+    # Build generation kwargs
+    gen_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": state['image_neg_prompt'],
+        "height": int(state['image_height']),
+        "width": int(state['image_width']),
+        "num_inference_steps": int(state['image_steps']),
+        "num_images_per_prompt": int(state['image_batch_size']),
+        "generator": generator,
+    }
+
+    # Add pipeline-specific parameters for CFG
+    cfg_val = state.get('image_cfg_scale', 0.0)
+
+    if pipeline_type == 'qwenimage':
+        # Qwen-Image uses true_cfg_scale (typically 4.0)
+        gen_kwargs["true_cfg_scale"] = cfg_val
+    else:
+        # Z-Image and others use guidance_scale (typically 0.0 for Turbo)
+        gen_kwargs["guidance_scale"] = cfg_val
+
+    t0 = time.time()
     for i in range(int(state['image_batch_count'])):
         generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(
-            prompt=state['image_prompt'],
-            negative_prompt=state['image_neg_prompt'],
-            height=int(state['image_height']),
-            width=int(state['image_width']),
-            num_inference_steps=int(state['image_steps']),
-            guidance_scale=0.0,
-            num_images_per_prompt=int(state['image_batch_size']),
-            generator=generator,
-        ).images
+        batch_results = shared.image_model(**gen_kwargs).images
         all_images.extend(batch_results)

-    save_generated_images(all_images, state['image_prompt'], seed)
-    return all_images, f"Seed: {seed}"
+    t1 = time.time()
+    save_generated_images(all_images, state, seed)
+
+    logger.info(f'Images generated in {(t1-t0):.2f} seconds ({state["image_steps"]/(t1-t0):.2f} steps/s, seed {seed})')
+    return all_images


-def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model):
+def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
     if not model_name or model_name == 'None':
         yield "No model selected"
         return
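Note how the loop reseeds one shared generator with seed + i per batch, so each batch is reproducible on its own. A CPU-only illustration of the same pattern:

```python
# CPU-only illustration of the per-batch reseeding pattern in generate().
import torch

seed = 12345
generator = torch.Generator("cpu")  # the actual code uses torch.Generator("cuda")

for i in range(3):
    generator.manual_seed(seed + i)  # batch i always starts from seed + i
    print(i, torch.randn(2, generator=generator))

# Re-running batch 1 alone yields identical noise:
generator.manual_seed(seed + 1)
print("replay", torch.randn(2, generator=generator))
```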
@@ -348,12 +712,13 @@ def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compi
             dtype=dtype,
             attn_backend=attn_backend,
             cpu_offload=cpu_offload,
-            compile_model=compile_model
+            compile_model=compile_model,
+            quant_method=quant_method
         )

         if result is not None:
             shared.image_model_name = model_name
-            yield f"✓ Loaded **{model_name}**"
+            yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
         else:
             yield f"✗ Failed to load `{model_name}`"
     except Exception:
@@ -375,6 +740,12 @@ def download_image_model_wrapper(model_path):
         return

     try:
+        model_path = model_path.strip()
+        if model_path.startswith('https://huggingface.co/'):
+            model_path = model_path[len('https://huggingface.co/'):]
+        elif model_path.startswith('huggingface.co/'):
+            model_path = model_path[len('huggingface.co/'):]
+
         if ':' in model_path:
             model_id, branch = model_path.rsplit(':', 1)
         else:
@@ -396,30 +767,3 @@ def download_image_model_wrapper(model_path):
         yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
     except Exception:
         yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
-
-
-def save_generated_images(images, prompt, seed):
-    date_str = datetime.now().strftime("%Y-%m-%d")
-    folder_path = os.path.join("user_data", "image_outputs", date_str)
-    os.makedirs(folder_path, exist_ok=True)
-
-    for idx, img in enumerate(images):
-        timestamp = datetime.now().strftime("%H-%M-%S")
-        filename = f"{timestamp}_{seed}_{idx}.png"
-        img.save(os.path.join(folder_path, filename))
-
-
-def get_history_images():
-    output_dir = os.path.join("user_data", "image_outputs")
-    if not os.path.exists(output_dir):
-        return []
-
-    image_files = []
-    for root, _, files in os.walk(output_dir):
-        for file in files:
-            if file.endswith((".png", ".jpg", ".jpeg")):
-                full_path = os.path.join(root, file)
-                image_files.append((full_path, os.path.getmtime(full_path)))
-
-    image_files.sort(key=lambda x: x[1], reverse=True)
-    return [x[0] for x in image_files]

server.py
@@ -172,7 +172,8 @@ def create_interface():
     ui_chat.create_event_handlers()
     ui_default.create_event_handlers()
     ui_notebook.create_event_handlers()
-    ui_image_generation.create_event_handlers()
+    if not shared.args.portable:
+        ui_image_generation.create_event_handlers()

     # Other events
     ui_file_saving.create_event_handlers()