diff --git a/README.md b/README.md
index d350d959..b1aeba48 100644
--- a/README.md
+++ b/README.md
@@ -432,6 +432,7 @@ https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/ma
 
 https://www.reddit.com/r/Oobabooga/
 
-## Acknowledgment
+## Acknowledgments
 
-In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition.
+- In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition.
+- This project was inspired by [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and wouldn't exist without it.
diff --git a/css/main.css b/css/main.css
index fd79d24c..0bfdca0a 100644
--- a/css/main.css
+++ b/css/main.css
@@ -93,11 +93,11 @@ ol li p, ul li p {
   display: inline-block;
 }
 
-#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
   border: 0;
 }
 
-#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
+#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
   padding: 1rem;
 }
 
@@ -1674,3 +1674,66 @@ button:focus {
 .dark .sidebar-vertical-separator {
   border-bottom: 1px solid rgb(255 255 255 / 10%);
 }
+
+button#swap-height-width {
+  position: absolute;
+  top: -50px;
+  right: 0;
+  border: 0;
+}
+
+#image-output-gallery, #image-output-gallery > :nth-child(2) {
+  height: calc(100vh - 83px);
+  max-height: calc(100vh - 83px);
+}
+
+#image-history-gallery, #image-history-gallery > :nth-child(2) {
+  height: calc(100vh - 174px);
+  max-height: calc(100vh - 174px);
+}
+
+/* Additional CSS for the paginated image gallery */
+
+/* Page info styling */
+#image-page-info {
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  min-width: 200px;
+  font-size: 0.9em;
+  color: var(--body-text-color-subdued);
+}
+
+/* Settings display panel */
+#image-ai-tab .settings-display-panel {
+  background: var(--background-fill-secondary);
+  padding: 12px;
+  border-radius: 8px;
+  font-size: 0.9em;
+  max-height: 300px;
+  overflow-y: auto;
+  margin-top: 8px;
+}
+
+/* Gallery status message */
+#image-ai-tab .gallery-status {
+  color: var(--color-accent);
+  font-size: 0.85em;
+  margin-top: 4px;
+}
+
+/* Pagination button row alignment */
+#image-ai-tab .pagination-controls {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  flex-wrap: wrap;
+}
+
+/* Selected image preview container */
+#image-ai-tab .selected-preview-container {
+  border: 1px solid var(--border-color-primary);
+  border-radius: 8px;
+  padding: 8px;
+  background: var(--background-fill-secondary);
+}
diff --git a/modules/image_models.py b/modules/image_models.py
new file mode 100644
index 00000000..e4831758
--- /dev/null
+++ b/modules/image_models.py
@@ -0,0 +1,183 @@
+import time
+
+import modules.shared as shared
+from modules.logging_colors import logger
+from modules.torch_utils import get_device
+from modules.utils import resolve_model_path
+
+
+def get_quantization_config(quant_method):
+    """
+    Get the appropriate quantization config based on the selected method.
+
+    Args:
+        quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+
+    Returns:
+        PipelineQuantizationConfig or None
+    """
+    import torch
+    from diffusers import BitsAndBytesConfig, QuantoConfig
+    from diffusers.quantizers import PipelineQuantizationConfig
+
+    if quant_method == 'none' or not quant_method:
+        return None
+
+    # Bitsandbytes 8-bit quantization
+    elif quant_method == 'bnb-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_8bit=True
+                )
+            }
+        )
+
+    # Bitsandbytes 4-bit quantization
+    elif quant_method == 'bnb-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                    bnb_4bit_use_double_quant=True
+                )
+            }
+        )
+
+    # Quanto 8-bit quantization
+    elif quant_method == 'quanto-8bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int8")
+            }
+        )
+
+    # Quanto 4-bit quantization
+    elif quant_method == 'quanto-4bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int4")
+            }
+        )
+
+    # Quanto 2-bit quantization
+    elif quant_method == 'quanto-2bit':
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": QuantoConfig(weights_dtype="int2")
+            }
+        )
+
+    else:
+        logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
+        return None
+
+
+def get_pipeline_type(pipe):
+    """
+    Detect the pipeline type based on the loaded pipeline class.
+
+    Returns:
+        str: 'zimage', 'qwenimage', or 'unknown'
+    """
+    class_name = pipe.__class__.__name__
+    if 'ZImage' in class_name:
+        return 'zimage'
+    elif 'QwenImage' in class_name:
+        return 'qwenimage'
+    else:
+        return 'unknown'
+
+
+def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
+    """
+    Load a diffusers image generation model.
+
+    Args:
+        model_name: Name of the model directory
+        dtype: 'bfloat16' or 'float16'
+        attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
+        cpu_offload: Enable CPU offloading for low VRAM
+        compile_model: Compile the model for faster inference (slow first run)
+        quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
+    """
+    import torch
+    from diffusers import DiffusionPipeline
+
+    logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
+    t0 = time.time()
+
+    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
+    target_dtype = dtype_map.get(dtype, torch.bfloat16)
+
+    model_path = resolve_model_path(model_name, image_model=True)
+
+    try:
+        # Get quantization config based on selected method
+        pipeline_quant_config = get_quantization_config(quant_method)
+
+        # Load the pipeline
+        load_kwargs = {
+            "torch_dtype": target_dtype,
+            "low_cpu_mem_usage": True,
+        }
+
+        if pipeline_quant_config is not None:
+            load_kwargs["quantization_config"] = pipeline_quant_config
+
+        # Use DiffusionPipeline for automatic pipeline detection
+        # This handles both ZImagePipeline and QwenImagePipeline
+        pipe = DiffusionPipeline.from_pretrained(
+            str(model_path),
+            **load_kwargs
+        )
+
+        pipeline_type = get_pipeline_type(pipe)
+
+        if not cpu_offload:
+            pipe.to(get_device())
+
+        # Set attention backend (if supported by the pipeline)
+        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
+            if attn_backend == 'flash_attention_2':
+                pipe.transformer.set_attention_backend("flash")
+            elif attn_backend == 'flash_attention_3':
+                pipe.transformer.set_attention_backend("_flash_3")
+            # sdpa is the default, no action needed
+
+        if compile_model:
+            if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
+                logger.info("Compiling model (first run will be slow)...")
+                pipe.transformer.compile()
+
+        if cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        shared.image_model = pipe
+        shared.image_model_name = model_name
+        shared.image_pipeline_type = pipeline_type
+
+        logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
+        return pipe
+
+    except Exception as e:
+        logger.error(f"Failed to load image model: {str(e)}")
+        return None
+
+
+def unload_image_model():
+    """Unload the current image model and free VRAM."""
+    if shared.image_model is None:
+        return
+
+    del shared.image_model
+    shared.image_model = None
+    shared.image_model_name = 'None'
+    shared.image_pipeline_type = None
+
+    from modules.torch_utils import clear_torch_cache
+    clear_torch_cache()
+
+    logger.info("Image model unloaded.")
diff --git a/modules/shared.py b/modules/shared.py
index 134c0cac..fef67489 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,7 +11,7 @@ import yaml
 from modules.logging_colors import logger
 from modules.presets import default_preset
 
-# Model variables
+# Text model variables
 model = None
 tokenizer = None
 model_name = 'None'
@@ -20,6 +20,10 @@ is_multimodal = False
 model_dirty_from_training = False
 lora_names = []
 
+# Image model variables
+image_model = None
+image_model_name = 'None'
+
 # Generation variables
 stop_everything = False
 generation_lock = None
@@ -46,6 +50,18 @@ group.add_argument('--extensions', type=str, nargs='+', help='The list of extens
 group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
 
+# Image generation
+group = parser.add_argument_group('Image model')
+group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
+group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
+group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
+group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
+group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
+group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
+group.add_argument('--image-quant', type=str, default=None,
+                   choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
+                   help='Quantization method for image model.')
+
 # Model loader
 group = parser.add_argument_group('Model loader')
 group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
@@ -290,6 +306,24 @@ settings = {
 
     # Extensions
     'default_extensions': [],
+
+    # Image generation settings
+    'image_prompt': '',
+    'image_neg_prompt': '',
+    'image_width': 1024,
+    'image_height': 1024,
+    'image_aspect_ratio': '1:1 Square',
+    'image_steps': 9,
+    'image_cfg_scale': 0.0,
+    'image_seed': -1,
+    'image_batch_size': 1,
+    'image_batch_count': 1,
+    'image_model_menu': 'None',
+    'image_dtype': 'bfloat16',
+    'image_attn_backend': 'sdpa',
+    'image_cpu_offload': False,
+    'image_compile': False,
+    'image_quant': 'none',
 }
 
 default_settings = copy.deepcopy(settings)
@@ -314,6 +348,23 @@ def do_cmd_flags_warnings():
         logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
 
 
+def apply_image_model_cli_overrides():
+    """Apply command-line overrides for image model settings."""
+    if args.image_model is not None:
+        settings['image_model_menu'] = args.image_model
+    if args.image_dtype is not None:
+        settings['image_dtype'] = args.image_dtype
+    if args.image_attn_backend is not None:
+        settings['image_attn_backend'] = args.image_attn_backend
+    if args.image_cpu_offload:
+        settings['image_cpu_offload'] = True
+    if args.image_compile:
+        settings['image_compile'] = True
+    if args.image_quant is not None:
+        settings['image_quant'] = args.image_quant
+
+
 def fix_loader_name(name):
     if not name:
         return name
diff --git a/modules/ui.py b/modules/ui.py
index f99e8b6a..9700d297 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -280,6 +280,26 @@ def list_interface_input_elements():
         'include_past_attachments',
     ]
 
+    # Image generation elements
+    elements += [
+        'image_prompt',
+        'image_neg_prompt',
+        'image_width',
+        'image_height',
+        'image_aspect_ratio',
+        'image_steps',
+        'image_cfg_scale',
+        'image_seed',
+        'image_batch_size',
+        'image_batch_count',
+        'image_model_menu',
+        'image_dtype',
+        'image_attn_backend',
+        'image_compile',
+        'image_cpu_offload',
+        'image_quant',
+    ]
+
     return elements
 
 
@@ -509,7 +529,25 @@ def setup_auto_save():
         'theme_state',
         'show_two_notebook_columns',
         'paste_to_attachment',
-        'include_past_attachments'
+        'include_past_attachments',
+
+        # Image generation tab (ui_image_generation.py)
+        'image_prompt',
+        'image_neg_prompt',
+        'image_width',
+        'image_height',
+        'image_aspect_ratio',
+        'image_steps',
+        'image_cfg_scale',
+        'image_seed',
+        'image_batch_size',
+        'image_batch_count',
+        'image_model_menu',
+        'image_dtype',
+        'image_attn_backend',
+        'image_compile',
+        'image_cpu_offload',
+        'image_quant',
     ]
 
     for element_name in change_elements:
diff --git a/modules/ui_image_generation.py b/modules/ui_image_generation.py
new file mode 100644
index 00000000..fa3d6791
--- /dev/null
+++ b/modules/ui_image_generation.py
@@ -0,0 +1,777 @@
+import json
+import os
+import time
+import traceback
+from datetime import datetime
+from pathlib import Path
+
+import gradio as gr
+import numpy as np
+from PIL import Image
+from PIL.PngImagePlugin import PngInfo
+
+from modules import shared, ui, utils
+from modules.image_models import load_image_model, unload_image_model
+from modules.logging_colors import logger
+from modules.utils import gradio
+
+ASPECT_RATIOS = {
+    "1:1 Square": (1, 1),
+    "16:9 Cinema": (16, 9),
+    "9:16 Mobile": (9, 16),
+    "4:3 Photo": (4, 3),
+    "Custom": None,
+}
+
+STEP = 16
+IMAGES_PER_PAGE = 64
+
+# Settings keys to save in PNG metadata (Generate tab only)
+METADATA_SETTINGS_KEYS = [
+    'image_prompt',
+    'image_neg_prompt',
+    'image_width',
+    'image_height',
+    'image_aspect_ratio',
+    'image_steps',
+    'image_seed',
+    'image_batch_size',
+    'image_batch_count',
+    'image_cfg_scale',
+]
+
+# Cache for all image paths
+_image_cache = []
+_cache_timestamp = 0
+
+
+def round_to_step(value, step=STEP):
+    return round(value / step) * step
+
+
+def clamp(value, min_val, max_val):
+    return max(min_val, min(max_val, value))
+
+
+def apply_aspect_ratio(aspect_ratio, current_width, current_height):
+    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+        return current_width, current_height
+
+    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+
+    if w_ratio == h_ratio:
+        base = min(current_width, current_height)
+        new_width = base
+        new_height = base
+    elif w_ratio < h_ratio:
+        new_width = current_width
+        new_height = round_to_step(current_width * h_ratio / w_ratio)
+    else:
+        new_height = current_height
+        new_width = round_to_step(current_height * w_ratio / h_ratio)
+
+    new_width = clamp(new_width, 256, 2048)
+    new_height = clamp(new_height, 256, 2048)
+
+    return int(new_width), int(new_height)
+
+
+def update_height_from_width(width, aspect_ratio):
+    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+        return gr.update()
+
+    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+    new_height = round_to_step(width * h_ratio / w_ratio)
+    new_height = clamp(new_height, 256, 2048)
+
+    return int(new_height)
+
+
+def update_width_from_height(height, aspect_ratio):
+    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
+        return gr.update()
+
+    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
+    new_width = round_to_step(height * w_ratio / h_ratio)
+    new_width = clamp(new_width, 256, 2048)
+
+    return int(new_width)
+
+
+def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
+    new_width, new_height = height, width
+
+    new_ratio = "Custom"
+    for name, ratios in ASPECT_RATIOS.items():
+        if ratios is None:
+            continue
+        w_r, h_r = ratios
+        expected_height = new_width * h_r / w_r
+        if abs(expected_height - new_height) < STEP:
+            new_ratio = name
+            break
+
+    return new_width, new_height, new_ratio
+
+
+def build_generation_metadata(state, actual_seed):
+    """Build metadata dict from generation settings."""
+    metadata = {}
+    for key in METADATA_SETTINGS_KEYS:
+        if key in state:
+            metadata[key] = state[key]
+
+    # Store the actual seed used (not -1)
+    metadata['image_seed'] = actual_seed
+    metadata['generated_at'] = datetime.now().isoformat()
+    metadata['model'] = shared.image_model_name
+
+    return metadata
+
+
+def save_generated_images(images, state, actual_seed):
+    """Save images with generation metadata embedded in PNG."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder_path = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder_path, exist_ok=True)
+
+    metadata = build_generation_metadata(state, actual_seed)
+    metadata_json = json.dumps(metadata, ensure_ascii=False)
+
+    for idx, img in enumerate(images):
+        timestamp = datetime.now().strftime("%H-%M-%S")
+        filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
+        filepath = os.path.join(folder_path, filename)
+
+        # Create PNG metadata
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", metadata_json)
+
+        # Save with metadata
+        img.save(filepath, pnginfo=png_info)
+
+
+def read_image_metadata(image_path):
+    """Read generation metadata from PNG file."""
+    try:
+        with Image.open(image_path) as img:
+            if hasattr(img, 'text') and 'image_gen_settings' in img.text:
+                return json.loads(img.text['image_gen_settings'])
+    except Exception as e:
+        logger.debug(f"Could not read metadata from {image_path}: {e}")
+    return None
+
+
+def format_metadata_for_display(metadata):
+    """Format metadata as readable text."""
+    if not metadata:
+        return "No generation settings found in this image."
+
+    lines = ["**Generation Settings**", ""]
+
+    # Display in a nice order
+    display_order = [
+        ('image_prompt', 'Prompt'),
+        ('image_neg_prompt', 'Negative Prompt'),
+        ('image_width', 'Width'),
+        ('image_height', 'Height'),
+        ('image_aspect_ratio', 'Aspect Ratio'),
+        ('image_steps', 'Steps'),
+        ('image_cfg_scale', 'CFG Scale'),
+        ('image_seed', 'Seed'),
+        ('image_batch_size', 'Batch Size'),
+        ('image_batch_count', 'Batch Count'),
+        ('model', 'Model'),
+        ('generated_at', 'Generated At'),
+    ]
+
+    for key, label in display_order:
+        if key in metadata:
+            value = metadata[key]
+            if key in ['image_prompt', 'image_neg_prompt'] and value:
+                # Truncate long prompts for display
+                if len(str(value)) > 200:
+                    value = str(value)[:200] + "..."
+            lines.append(f"**{label}:** {value}")
+
+    return "\n\n".join(lines)
+
+
+def get_all_history_images(force_refresh=False):
+    """Get all history images sorted by modification time (newest first). Uses caching."""
+    global _image_cache, _cache_timestamp
+
+    output_dir = os.path.join("user_data", "image_outputs")
+    if not os.path.exists(output_dir):
+        return []
+
+    # Check if we need to refresh cache
+    current_time = time.time()
+    if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
+        return _image_cache
+
+    image_files = []
+    for root, _, files in os.walk(output_dir):
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg")):
+                full_path = os.path.join(root, file)
+                image_files.append((full_path, os.path.getmtime(full_path)))
+
+    image_files.sort(key=lambda x: x[1], reverse=True)
+    _image_cache = [x[0] for x in image_files]
+    _cache_timestamp = current_time
+
+    return _image_cache
+
+
+def get_paginated_images(page=0, force_refresh=False):
+    """Get images for a specific page."""
+    all_images = get_all_history_images(force_refresh)
+    total_images = len(all_images)
+    total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
+
+    # Clamp page to valid range
+    page = max(0, min(page, total_pages - 1))
+
+    start_idx = page * IMAGES_PER_PAGE
+    end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
+
+    page_images = all_images[start_idx:end_idx]
+
+    return page_images, page, total_pages, total_images
+
+
+def refresh_gallery(current_page=0):
+    """Refresh gallery with current page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def go_to_page(page_num, current_page):
+    """Go to a specific page (1-indexed input)."""
+    try:
+        page = int(page_num) - 1  # Convert to 0-indexed
+    except (ValueError, TypeError):
+        page = current_page
+
+    images, page, total_pages, total_images = get_paginated_images(page)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def next_page(current_page):
+    """Go to next page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page + 1)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def prev_page(current_page):
+    """Go to previous page."""
+    images, page, total_pages, total_images = get_paginated_images(current_page - 1)
+    page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
+    return images, page, page_info
+
+
+def on_gallery_select(evt: gr.SelectData, current_page):
+    """Handle image selection from gallery."""
+    if evt.index is None:
+        return "", "Select an image to view its settings"
+
+    # Get the current page's images to find the actual file path
+    all_images = get_all_history_images()
+    total_images = len(all_images)
+
+    # Calculate the actual index in the full list
+    start_idx = current_page * IMAGES_PER_PAGE
+    actual_idx = start_idx + evt.index
+
+    if actual_idx >= total_images:
+        return "", "Image not found"
+
+    image_path = all_images[actual_idx]
+    metadata = read_image_metadata(image_path)
+    metadata_display = format_metadata_for_display(metadata)
+
+    return image_path, metadata_display
+
+
+def send_to_generate(selected_image_path):
+    """Load settings from selected image and return updates for all Generate tab inputs."""
+    if not selected_image_path or not os.path.exists(selected_image_path):
+        return [gr.update()] * 10 + ["No image selected"]
+
+    metadata = read_image_metadata(selected_image_path)
+    if not metadata:
["No settings found in this image"] + + # Return updates for each input element in order + updates = [ + gr.update(value=metadata.get('image_prompt', '')), + gr.update(value=metadata.get('image_neg_prompt', '')), + gr.update(value=metadata.get('image_width', 1024)), + gr.update(value=metadata.get('image_height', 1024)), + gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')), + gr.update(value=metadata.get('image_steps', 9)), + gr.update(value=metadata.get('image_seed', -1)), + gr.update(value=metadata.get('image_batch_size', 1)), + gr.update(value=metadata.get('image_batch_count', 1)), + gr.update(value=metadata.get('image_cfg_scale', 0.0)), + ] + + status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})" + return updates + [status] + + +def create_ui(): + if shared.settings['image_model_menu'] != 'None': + shared.image_model_name = shared.settings['image_model_menu'] + + with gr.Tab("Image AI", elem_id="image-ai-tab"): + with gr.Tabs(): + # TAB 1: GENERATE + with gr.TabItem("Generate"): + with gr.Row(): + with gr.Column(scale=4, min_width=350): + shared.gradio['image_prompt'] = gr.Textbox( + label="Prompt", + placeholder="Describe your imagination...", + lines=3, + autofocus=True, + value=shared.settings['image_prompt'] + ) + shared.gradio['image_neg_prompt'] = gr.Textbox( + label="Negative Prompt", + placeholder="Low quality...", + lines=3, + value=shared.settings['image_neg_prompt'] + ) + + shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg") + shared.gradio['image_generating_btn'] = gr.Button("Generating...", size="lg", visible=False, interactive=False) + gr.HTML("