mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Add image generation support (#7328)
This commit is contained in:
parent
a83821e941
commit
b3666e140d
|
|
@ -28,6 +28,8 @@ A Gradio web UI for Large Language Models.
|
||||||
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
|
||||||
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
|
||||||
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
|
||||||
|
Image generation: A dedicated tab for diffusers models like Z-Image-Turbo and Qwen-Image. Features 4-bit/8-bit quantization and a persistent gallery with metadata (tutorial).
|
||||||
|
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo** and **Qwen-Image**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
|
||||||
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
|
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
|
||||||
- Aesthetic UI with dark and light themes.
|
- Aesthetic UI with dark and light themes.
|
||||||
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
|
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
|
||||||
|
|
@ -432,6 +434,7 @@ https://colab.research.google.com/github/oobabooga/text-generation-webui/blob/ma
|
||||||
|
|
||||||
https://www.reddit.com/r/Oobabooga/
|
https://www.reddit.com/r/Oobabooga/
|
||||||
|
|
||||||
## Acknowledgment
|
## Acknowledgments
|
||||||
|
|
||||||
In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition.
|
- In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition.
|
||||||
|
- This project was inspired by [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and wouldn't exist without it.
|
||||||
|
|
|
||||||
96
css/main.css
96
css/main.css
|
|
@ -93,11 +93,11 @@ ol li p, ul li p {
|
||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
|
||||||
#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
|
#notebook-parent-tab, #chat-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
|
||||||
border: 0;
|
border: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab {
|
#notebook-parent-tab, #parameters, #chat-settings, #lora, #training-tab, #model-tab, #session-tab, #character-tab, #image-ai-tab {
|
||||||
padding: 1rem;
|
padding: 1rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,37 +244,46 @@ button {
|
||||||
font-size: 100% !important;
|
font-size: 100% !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar {
|
.pretty_scrollbar::-webkit-scrollbar,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar {
|
||||||
width: 8px;
|
width: 8px;
|
||||||
height: 8px;
|
height: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar-track {
|
.pretty_scrollbar::-webkit-scrollbar-track,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-track {
|
||||||
background: transparent;
|
background: transparent;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar-thumb,
|
.pretty_scrollbar::-webkit-scrollbar-thumb,
|
||||||
.pretty_scrollbar::-webkit-scrollbar-thumb:hover {
|
.pretty_scrollbar::-webkit-scrollbar-thumb:hover,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
||||||
background: var(--neutral-300);
|
background: var(--neutral-300);
|
||||||
border-radius: 30px;
|
border-radius: 30px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .pretty_scrollbar::-webkit-scrollbar-thumb,
|
.dark .pretty_scrollbar::-webkit-scrollbar-thumb,
|
||||||
.dark .pretty_scrollbar::-webkit-scrollbar-thumb:hover {
|
.dark .pretty_scrollbar::-webkit-scrollbar-thumb:hover,
|
||||||
|
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb,
|
||||||
|
.dark #image-history-gallery > :nth-child(2)::-webkit-scrollbar-thumb:hover {
|
||||||
background: rgb(255 255 255 / 6.25%);
|
background: rgb(255 255 255 / 6.25%);
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-resizer {
|
.pretty_scrollbar::-webkit-resizer,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-resizer {
|
||||||
background: #c5c5d2;
|
background: #c5c5d2;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .pretty_scrollbar::-webkit-resizer {
|
.dark .pretty_scrollbar::-webkit-resizer,
|
||||||
|
.dark #image-history-gallery > :nth-child(2)::-webkit-resizer {
|
||||||
background: #ccc;
|
background: #ccc;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.pretty_scrollbar::-webkit-scrollbar-corner {
|
.pretty_scrollbar::-webkit-scrollbar-corner,
|
||||||
|
#image-history-gallery > :nth-child(2)::-webkit-scrollbar-corner {
|
||||||
background: transparent;
|
background: transparent;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1674,3 +1683,72 @@ button:focus {
|
||||||
.dark .sidebar-vertical-separator {
|
.dark .sidebar-vertical-separator {
|
||||||
border-bottom: 1px solid rgb(255 255 255 / 10%);
|
border-bottom: 1px solid rgb(255 255 255 / 10%);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
button#swap-height-width {
|
||||||
|
position: absolute;
|
||||||
|
top: -50px;
|
||||||
|
right: 0;
|
||||||
|
border: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#image-output-gallery, #image-output-gallery > :nth-child(2) {
|
||||||
|
height: calc(100vh - 83px);
|
||||||
|
max-height: calc(100vh - 83px);
|
||||||
|
}
|
||||||
|
|
||||||
|
#image-history-gallery, #image-history-gallery > :nth-child(2) {
|
||||||
|
height: calc(100vh - 174px);
|
||||||
|
max-height: calc(100vh - 174px);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Additional CSS for the paginated image gallery */
|
||||||
|
|
||||||
|
/* Page info styling */
|
||||||
|
#image-page-info {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
min-width: 200px;
|
||||||
|
font-size: 0.9em;
|
||||||
|
color: var(--body-text-color-subdued);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Settings display panel */
|
||||||
|
#image-ai-tab .settings-display-panel {
|
||||||
|
background: var(--background-fill-secondary);
|
||||||
|
padding: 12px;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 0.9em;
|
||||||
|
max-height: 300px;
|
||||||
|
overflow-y: auto;
|
||||||
|
margin-top: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Gallery status message */
|
||||||
|
#image-ai-tab .gallery-status {
|
||||||
|
color: var(--color-accent);
|
||||||
|
font-size: 0.85em;
|
||||||
|
margin-top: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Pagination button row alignment */
|
||||||
|
#image-ai-tab .pagination-controls {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Selected image preview container */
|
||||||
|
#image-ai-tab .selected-preview-container {
|
||||||
|
border: 1px solid var(--border-color-primary);
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 8px;
|
||||||
|
background: var(--background-fill-secondary);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Fix a gr.Markdown UI glitch when clicking Next in the
|
||||||
|
* Image AI > Gallery tab */
|
||||||
|
.min.svelte-1yrv54 {
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
|
||||||
20
docs/Image Generation Tutorial.md
Normal file
20
docs/Image Generation Tutorial.md
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Image Generation Tutorial
|
||||||
|
|
||||||
|
This feature allows you to generate images using high-speed models like Z-Image-Turbo directly within the web UI.
|
||||||
|
|
||||||
|
## How to use
|
||||||
|
|
||||||
|
1. Click on the **Image AI** tab at the top of the interface.
|
||||||
|
2. Select the **Model** sub-tab.
|
||||||
|
3. Copy and paste the following link into the **Download model** box:
|
||||||
|
|
||||||
|
```
|
||||||
|
https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Click the **Download** button and wait for the confirmation message.
|
||||||
|
5. In the **Model** dropdown menu, select the model you just downloaded (if you don't see it, click the 🔄 refresh button).
|
||||||
|
6. Click **Load**.
|
||||||
|
7. Go to the **Generate** sub-tab, type a prompt, and click **GENERATE**.
|
||||||
|
|
||||||
|
> **Note for Z-Image-Turbo:** For the best results with this specific model, keep the **CFG Scale** slider at **0**.
|
||||||
183
modules/image_models.py
Normal file
183
modules/image_models.py
Normal file
|
|
@ -0,0 +1,183 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.torch_utils import get_device
|
||||||
|
from modules.utils import resolve_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_quantization_config(quant_method):
|
||||||
|
"""
|
||||||
|
Get the appropriate quantization config based on the selected method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_method: One of 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PipelineQuantizationConfig or None
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from diffusers import BitsAndBytesConfig, QuantoConfig
|
||||||
|
from diffusers.quantizers import PipelineQuantizationConfig
|
||||||
|
|
||||||
|
if quant_method == 'none' or not quant_method:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Bitsandbytes 8-bit quantization
|
||||||
|
elif quant_method == 'bnb-8bit':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": BitsAndBytesConfig(
|
||||||
|
load_in_8bit=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bitsandbytes 4-bit quantization
|
||||||
|
elif quant_method == 'bnb-4bit':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||||
|
bnb_4bit_use_double_quant=True
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quanto 8-bit quantization
|
||||||
|
elif quant_method == 'quanto-8bit':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": QuantoConfig(weights_dtype="int8")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quanto 4-bit quantization
|
||||||
|
elif quant_method == 'quanto-4bit':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": QuantoConfig(weights_dtype="int4")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quanto 2-bit quantization
|
||||||
|
elif quant_method == 'quanto-2bit':
|
||||||
|
return PipelineQuantizationConfig(
|
||||||
|
quant_mapping={
|
||||||
|
"transformer": QuantoConfig(weights_dtype="int2")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_pipeline_type(pipe):
|
||||||
|
"""
|
||||||
|
Detect the pipeline type based on the loaded pipeline class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 'zimage', 'qwenimage', or 'unknown'
|
||||||
|
"""
|
||||||
|
class_name = pipe.__class__.__name__
|
||||||
|
if class_name == 'ZImagePipeline':
|
||||||
|
return 'zimage'
|
||||||
|
elif class_name == 'QwenImagePipeline':
|
||||||
|
return 'qwenimage'
|
||||||
|
else:
|
||||||
|
return 'unknown'
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'):
|
||||||
|
"""
|
||||||
|
Load a diffusers image generation model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model directory
|
||||||
|
dtype: 'bfloat16' or 'float16'
|
||||||
|
attn_backend: 'sdpa', 'flash_attention_2', or 'flash_attention_3'
|
||||||
|
cpu_offload: Enable CPU offloading for low VRAM
|
||||||
|
compile_model: Compile the model for faster inference (slow first run)
|
||||||
|
quant_method: Quantization method - 'none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}")
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16}
|
||||||
|
target_dtype = dtype_map.get(dtype, torch.bfloat16)
|
||||||
|
|
||||||
|
model_path = resolve_model_path(model_name, image_model=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get quantization config based on selected method
|
||||||
|
pipeline_quant_config = get_quantization_config(quant_method)
|
||||||
|
|
||||||
|
# Load the pipeline
|
||||||
|
load_kwargs = {
|
||||||
|
"torch_dtype": target_dtype,
|
||||||
|
"low_cpu_mem_usage": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if pipeline_quant_config is not None:
|
||||||
|
load_kwargs["quantization_config"] = pipeline_quant_config
|
||||||
|
|
||||||
|
# Use DiffusionPipeline for automatic pipeline detection
|
||||||
|
# This handles both ZImagePipeline and QwenImagePipeline
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
|
str(model_path),
|
||||||
|
**load_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_type = get_pipeline_type(pipe)
|
||||||
|
|
||||||
|
if not cpu_offload:
|
||||||
|
pipe.to(get_device())
|
||||||
|
|
||||||
|
# Set attention backend (if supported by the pipeline)
|
||||||
|
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
|
||||||
|
if attn_backend == 'flash_attention_2':
|
||||||
|
pipe.transformer.set_attention_backend("flash")
|
||||||
|
elif attn_backend == 'flash_attention_3':
|
||||||
|
pipe.transformer.set_attention_backend("_flash_3")
|
||||||
|
# sdpa is the default, no action needed
|
||||||
|
|
||||||
|
if compile_model:
|
||||||
|
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
|
||||||
|
logger.info("Compiling model (first run will be slow)...")
|
||||||
|
pipe.transformer.compile()
|
||||||
|
|
||||||
|
if cpu_offload:
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
shared.image_model = pipe
|
||||||
|
shared.image_model_name = model_name
|
||||||
|
shared.image_pipeline_type = pipeline_type
|
||||||
|
|
||||||
|
logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.")
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load image model: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def unload_image_model():
|
||||||
|
"""Unload the current image model and free VRAM."""
|
||||||
|
if shared.image_model is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
del shared.image_model
|
||||||
|
shared.image_model = None
|
||||||
|
shared.image_model_name = 'None'
|
||||||
|
shared.image_pipeline_type = None
|
||||||
|
|
||||||
|
from modules.torch_utils import clear_torch_cache
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
logger.info("Image model unloaded.")
|
||||||
|
|
@ -11,7 +11,7 @@ import yaml
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.presets import default_preset
|
from modules.presets import default_preset
|
||||||
|
|
||||||
# Model variables
|
# Text model variables
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
model_name = 'None'
|
model_name = 'None'
|
||||||
|
|
@ -20,6 +20,11 @@ is_multimodal = False
|
||||||
model_dirty_from_training = False
|
model_dirty_from_training = False
|
||||||
lora_names = []
|
lora_names = []
|
||||||
|
|
||||||
|
# Image model variables
|
||||||
|
image_model = None
|
||||||
|
image_model_name = 'None'
|
||||||
|
image_pipeline_type = None
|
||||||
|
|
||||||
# Generation variables
|
# Generation variables
|
||||||
stop_everything = False
|
stop_everything = False
|
||||||
generation_lock = None
|
generation_lock = None
|
||||||
|
|
@ -46,6 +51,18 @@ group.add_argument('--extensions', type=str, nargs='+', help='The list of extens
|
||||||
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||||
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
|
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
|
||||||
|
|
||||||
|
# Image generation
|
||||||
|
group = parser.add_argument_group('Image model')
|
||||||
|
group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).')
|
||||||
|
group.add_argument('--image-model-dir', type=str, default='user_data/image_models', help='Path to directory with all the image models.')
|
||||||
|
group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.')
|
||||||
|
group.add_argument('--image-attn-backend', type=str, default=None, choices=['sdpa', 'flash_attention_2', 'flash_attention_3'], help='Attention backend for image model.')
|
||||||
|
group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.')
|
||||||
|
group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.')
|
||||||
|
group.add_argument('--image-quant', type=str, default=None,
|
||||||
|
choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
|
||||||
|
help='Quantization method for image model.')
|
||||||
|
|
||||||
# Model loader
|
# Model loader
|
||||||
group = parser.add_argument_group('Model loader')
|
group = parser.add_argument_group('Model loader')
|
||||||
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
|
group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav2_HF, ExLlamav2, TensorRT-LLM.')
|
||||||
|
|
@ -290,6 +307,24 @@ settings = {
|
||||||
|
|
||||||
# Extensions
|
# Extensions
|
||||||
'default_extensions': [],
|
'default_extensions': [],
|
||||||
|
|
||||||
|
# Image generation settings
|
||||||
|
'image_prompt': '',
|
||||||
|
'image_neg_prompt': '',
|
||||||
|
'image_width': 1024,
|
||||||
|
'image_height': 1024,
|
||||||
|
'image_aspect_ratio': '1:1 Square',
|
||||||
|
'image_steps': 9,
|
||||||
|
'image_cfg_scale': 0.0,
|
||||||
|
'image_seed': -1,
|
||||||
|
'image_batch_size': 1,
|
||||||
|
'image_batch_count': 1,
|
||||||
|
'image_model_menu': 'None',
|
||||||
|
'image_dtype': 'bfloat16',
|
||||||
|
'image_attn_backend': 'sdpa',
|
||||||
|
'image_cpu_offload': False,
|
||||||
|
'image_compile': False,
|
||||||
|
'image_quant': 'none',
|
||||||
}
|
}
|
||||||
|
|
||||||
default_settings = copy.deepcopy(settings)
|
default_settings = copy.deepcopy(settings)
|
||||||
|
|
@ -314,6 +349,22 @@ def do_cmd_flags_warnings():
|
||||||
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
|
||||||
|
|
||||||
|
|
||||||
|
def apply_image_model_cli_overrides():
|
||||||
|
"""Apply command-line overrides for image model settings."""
|
||||||
|
if args.image_model is not None:
|
||||||
|
settings['image_model_menu'] = args.image_model
|
||||||
|
if args.image_dtype is not None:
|
||||||
|
settings['image_dtype'] = args.image_dtype
|
||||||
|
if args.image_attn_backend is not None:
|
||||||
|
settings['image_attn_backend'] = args.image_attn_backend
|
||||||
|
if args.image_cpu_offload:
|
||||||
|
settings['image_cpu_offload'] = True
|
||||||
|
if args.image_compile:
|
||||||
|
settings['image_compile'] = True
|
||||||
|
if args.image_quant is not None:
|
||||||
|
settings['image_quant'] = args.image_quant
|
||||||
|
|
||||||
|
|
||||||
def fix_loader_name(name):
|
def fix_loader_name(name):
|
||||||
if not name:
|
if not name:
|
||||||
return name
|
return name
|
||||||
|
|
|
||||||
|
|
@ -280,6 +280,26 @@ def list_interface_input_elements():
|
||||||
'include_past_attachments',
|
'include_past_attachments',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Image generation elements
|
||||||
|
elements += [
|
||||||
|
'image_prompt',
|
||||||
|
'image_neg_prompt',
|
||||||
|
'image_width',
|
||||||
|
'image_height',
|
||||||
|
'image_aspect_ratio',
|
||||||
|
'image_steps',
|
||||||
|
'image_cfg_scale',
|
||||||
|
'image_seed',
|
||||||
|
'image_batch_size',
|
||||||
|
'image_batch_count',
|
||||||
|
'image_model_menu',
|
||||||
|
'image_dtype',
|
||||||
|
'image_attn_backend',
|
||||||
|
'image_compile',
|
||||||
|
'image_cpu_offload',
|
||||||
|
'image_quant',
|
||||||
|
]
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -509,7 +529,25 @@ def setup_auto_save():
|
||||||
'theme_state',
|
'theme_state',
|
||||||
'show_two_notebook_columns',
|
'show_two_notebook_columns',
|
||||||
'paste_to_attachment',
|
'paste_to_attachment',
|
||||||
'include_past_attachments'
|
'include_past_attachments',
|
||||||
|
|
||||||
|
# Image generation tab (ui_image_generation.py)
|
||||||
|
'image_prompt',
|
||||||
|
'image_neg_prompt',
|
||||||
|
'image_width',
|
||||||
|
'image_height',
|
||||||
|
'image_aspect_ratio',
|
||||||
|
'image_steps',
|
||||||
|
'image_cfg_scale',
|
||||||
|
'image_seed',
|
||||||
|
'image_batch_size',
|
||||||
|
'image_batch_count',
|
||||||
|
'image_model_menu',
|
||||||
|
'image_dtype',
|
||||||
|
'image_attn_backend',
|
||||||
|
'image_compile',
|
||||||
|
'image_cpu_offload',
|
||||||
|
'image_quant',
|
||||||
]
|
]
|
||||||
|
|
||||||
for element_name in change_elements:
|
for element_name in change_elements:
|
||||||
|
|
|
||||||
847
modules/ui_image_generation.py
Normal file
847
modules/ui_image_generation.py
Normal file
|
|
@ -0,0 +1,847 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
|
from modules import shared, ui, utils
|
||||||
|
from modules.image_models import (
|
||||||
|
get_pipeline_type,
|
||||||
|
load_image_model,
|
||||||
|
unload_image_model
|
||||||
|
)
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import stop_everything_event
|
||||||
|
from modules.torch_utils import get_device
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
|
ASPECT_RATIOS = {
|
||||||
|
"1:1 Square": (1, 1),
|
||||||
|
"16:9 Cinema": (16, 9),
|
||||||
|
"9:16 Mobile": (9, 16),
|
||||||
|
"4:3 Photo": (4, 3),
|
||||||
|
"Custom": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
STEP = 16
|
||||||
|
IMAGES_PER_PAGE = 64
|
||||||
|
|
||||||
|
# Settings keys to save in PNG metadata (Generate tab only)
|
||||||
|
METADATA_SETTINGS_KEYS = [
|
||||||
|
'image_prompt',
|
||||||
|
'image_neg_prompt',
|
||||||
|
'image_width',
|
||||||
|
'image_height',
|
||||||
|
'image_aspect_ratio',
|
||||||
|
'image_steps',
|
||||||
|
'image_seed',
|
||||||
|
'image_batch_size',
|
||||||
|
'image_batch_count',
|
||||||
|
'image_cfg_scale',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Cache for all image paths
|
||||||
|
_image_cache = []
|
||||||
|
_cache_timestamp = 0
|
||||||
|
|
||||||
|
|
||||||
|
def round_to_step(value, step=STEP):
|
||||||
|
return round(value / step) * step
|
||||||
|
|
||||||
|
|
||||||
|
def clamp(value, min_val, max_val):
|
||||||
|
return max(min_val, min(max_val, value))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_aspect_ratio(aspect_ratio, current_width, current_height):
|
||||||
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
|
return current_width, current_height
|
||||||
|
|
||||||
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
|
||||||
|
if w_ratio == h_ratio:
|
||||||
|
base = min(current_width, current_height)
|
||||||
|
new_width = base
|
||||||
|
new_height = base
|
||||||
|
elif w_ratio < h_ratio:
|
||||||
|
new_width = current_width
|
||||||
|
new_height = round_to_step(current_width * h_ratio / w_ratio)
|
||||||
|
else:
|
||||||
|
new_height = current_height
|
||||||
|
new_width = round_to_step(current_height * w_ratio / h_ratio)
|
||||||
|
|
||||||
|
new_width = clamp(new_width, 256, 2048)
|
||||||
|
new_height = clamp(new_height, 256, 2048)
|
||||||
|
|
||||||
|
return int(new_width), int(new_height)
|
||||||
|
|
||||||
|
|
||||||
|
def update_height_from_width(width, aspect_ratio):
|
||||||
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
new_height = round_to_step(width * h_ratio / w_ratio)
|
||||||
|
new_height = clamp(new_height, 256, 2048)
|
||||||
|
|
||||||
|
return int(new_height)
|
||||||
|
|
||||||
|
|
||||||
|
def update_width_from_height(height, aspect_ratio):
|
||||||
|
if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
new_width = round_to_step(height * w_ratio / h_ratio)
|
||||||
|
new_width = clamp(new_width, 256, 2048)
|
||||||
|
|
||||||
|
return int(new_width)
|
||||||
|
|
||||||
|
|
||||||
|
def swap_dimensions_and_update_ratio(width, height, aspect_ratio):
|
||||||
|
new_width, new_height = height, width
|
||||||
|
|
||||||
|
new_ratio = "Custom"
|
||||||
|
for name, ratios in ASPECT_RATIOS.items():
|
||||||
|
if ratios is None:
|
||||||
|
continue
|
||||||
|
w_r, h_r = ratios
|
||||||
|
expected_height = new_width * h_r / w_r
|
||||||
|
if abs(expected_height - new_height) < STEP:
|
||||||
|
new_ratio = name
|
||||||
|
break
|
||||||
|
|
||||||
|
return new_width, new_height, new_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def build_generation_metadata(state, actual_seed):
|
||||||
|
"""Build metadata dict from generation settings."""
|
||||||
|
metadata = {}
|
||||||
|
for key in METADATA_SETTINGS_KEYS:
|
||||||
|
if key in state:
|
||||||
|
metadata[key] = state[key]
|
||||||
|
|
||||||
|
# Store the actual seed used (not -1)
|
||||||
|
metadata['image_seed'] = actual_seed
|
||||||
|
metadata['generated_at'] = datetime.now().isoformat()
|
||||||
|
metadata['model'] = shared.image_model_name
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def save_generated_images(images, state, actual_seed):
|
||||||
|
"""Save images with generation metadata embedded in PNG."""
|
||||||
|
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
folder_path = os.path.join("user_data", "image_outputs", date_str)
|
||||||
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
|
|
||||||
|
metadata = build_generation_metadata(state, actual_seed)
|
||||||
|
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||||
|
|
||||||
|
for idx, img in enumerate(images):
|
||||||
|
timestamp = datetime.now().strftime("%H-%M-%S")
|
||||||
|
filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
|
||||||
|
filepath = os.path.join(folder_path, filename)
|
||||||
|
|
||||||
|
# Create PNG metadata
|
||||||
|
png_info = PngInfo()
|
||||||
|
png_info.add_text("image_gen_settings", metadata_json)
|
||||||
|
|
||||||
|
# Save with metadata
|
||||||
|
img.save(filepath, pnginfo=png_info)
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_metadata(image_path):
|
||||||
|
"""Read generation metadata from PNG file."""
|
||||||
|
try:
|
||||||
|
with Image.open(image_path) as img:
|
||||||
|
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
|
||||||
|
return json.loads(img.text['image_gen_settings'])
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not read metadata from {image_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def format_metadata_for_display(metadata):
|
||||||
|
"""Format metadata as readable text."""
|
||||||
|
if not metadata:
|
||||||
|
return "No generation settings found in this image."
|
||||||
|
|
||||||
|
lines = ["**Generation Settings**", ""]
|
||||||
|
|
||||||
|
# Display in a nice order
|
||||||
|
display_order = [
|
||||||
|
('image_prompt', 'Prompt'),
|
||||||
|
('image_neg_prompt', 'Negative Prompt'),
|
||||||
|
('image_width', 'Width'),
|
||||||
|
('image_height', 'Height'),
|
||||||
|
('image_aspect_ratio', 'Aspect Ratio'),
|
||||||
|
('image_steps', 'Steps'),
|
||||||
|
('image_cfg_scale', 'CFG Scale'),
|
||||||
|
('image_seed', 'Seed'),
|
||||||
|
('image_batch_size', 'Batch Size'),
|
||||||
|
('image_batch_count', 'Batch Count'),
|
||||||
|
('model', 'Model'),
|
||||||
|
('generated_at', 'Generated At'),
|
||||||
|
]
|
||||||
|
|
||||||
|
for key, label in display_order:
|
||||||
|
if key in metadata:
|
||||||
|
value = metadata[key]
|
||||||
|
if key in ['image_prompt', 'image_neg_prompt'] and value:
|
||||||
|
# Truncate long prompts for display
|
||||||
|
if len(str(value)) > 200:
|
||||||
|
value = str(value)[:200] + "..."
|
||||||
|
lines.append(f"**{label}:** {value}")
|
||||||
|
|
||||||
|
return "\n\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_history_images(force_refresh=False):
|
||||||
|
"""Get all history images sorted by modification time (newest first). Uses caching."""
|
||||||
|
global _image_cache, _cache_timestamp
|
||||||
|
|
||||||
|
output_dir = os.path.join("user_data", "image_outputs")
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Check if we need to refresh cache
|
||||||
|
current_time = time.time()
|
||||||
|
if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2:
|
||||||
|
return _image_cache
|
||||||
|
|
||||||
|
image_files = []
|
||||||
|
for root, _, files in os.walk(output_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith((".png", ".jpg", ".jpeg")):
|
||||||
|
full_path = os.path.join(root, file)
|
||||||
|
image_files.append((full_path, os.path.getmtime(full_path)))
|
||||||
|
|
||||||
|
image_files.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
_image_cache = [x[0] for x in image_files]
|
||||||
|
_cache_timestamp = current_time
|
||||||
|
|
||||||
|
return _image_cache
|
||||||
|
|
||||||
|
|
||||||
|
def get_paginated_images(page=0, force_refresh=False):
|
||||||
|
"""Get images for a specific page."""
|
||||||
|
all_images = get_all_history_images(force_refresh)
|
||||||
|
total_images = len(all_images)
|
||||||
|
total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE)
|
||||||
|
|
||||||
|
# Clamp page to valid range
|
||||||
|
page = max(0, min(page, total_pages - 1))
|
||||||
|
|
||||||
|
start_idx = page * IMAGES_PER_PAGE
|
||||||
|
end_idx = min(start_idx + IMAGES_PER_PAGE, total_images)
|
||||||
|
|
||||||
|
page_images = all_images[start_idx:end_idx]
|
||||||
|
|
||||||
|
return page_images, page, total_pages, total_images
|
||||||
|
|
||||||
|
|
||||||
|
def get_initial_page_info():
|
||||||
|
"""Get page info string for initial load."""
|
||||||
|
_, page, total_pages, total_images = get_paginated_images(0)
|
||||||
|
return f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_gallery(current_page=0):
|
||||||
|
"""Refresh gallery with current page."""
|
||||||
|
images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True)
|
||||||
|
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||||
|
return images, page, page_info
|
||||||
|
|
||||||
|
|
||||||
|
def go_to_page(page_num, current_page):
|
||||||
|
"""Go to a specific page (1-indexed input)."""
|
||||||
|
try:
|
||||||
|
page = int(page_num) - 1 # Convert to 0-indexed
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
page = current_page
|
||||||
|
|
||||||
|
images, page, total_pages, total_images = get_paginated_images(page)
|
||||||
|
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||||
|
return images, page, page_info
|
||||||
|
|
||||||
|
|
||||||
|
def next_page(current_page):
|
||||||
|
"""Go to next page."""
|
||||||
|
images, page, total_pages, total_images = get_paginated_images(current_page + 1)
|
||||||
|
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||||
|
return images, page, page_info
|
||||||
|
|
||||||
|
|
||||||
|
def prev_page(current_page):
|
||||||
|
"""Go to previous page."""
|
||||||
|
images, page, total_pages, total_images = get_paginated_images(current_page - 1)
|
||||||
|
page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)"
|
||||||
|
return images, page, page_info
|
||||||
|
|
||||||
|
|
||||||
|
def on_gallery_select(evt: gr.SelectData, current_page):
|
||||||
|
"""Handle image selection from gallery."""
|
||||||
|
if evt.index is None:
|
||||||
|
return "", "Select an image to view its settings"
|
||||||
|
|
||||||
|
# Get the current page's images to find the actual file path
|
||||||
|
all_images = get_all_history_images()
|
||||||
|
total_images = len(all_images)
|
||||||
|
|
||||||
|
# Calculate the actual index in the full list
|
||||||
|
start_idx = current_page * IMAGES_PER_PAGE
|
||||||
|
actual_idx = start_idx + evt.index
|
||||||
|
|
||||||
|
if actual_idx >= total_images:
|
||||||
|
return "", "Image not found"
|
||||||
|
|
||||||
|
image_path = all_images[actual_idx]
|
||||||
|
metadata = read_image_metadata(image_path)
|
||||||
|
metadata_display = format_metadata_for_display(metadata)
|
||||||
|
|
||||||
|
return image_path, metadata_display
|
||||||
|
|
||||||
|
|
||||||
|
def send_to_generate(selected_image_path):
|
||||||
|
"""Load settings from selected image and return updates for all Generate tab inputs."""
|
||||||
|
if not selected_image_path or not os.path.exists(selected_image_path):
|
||||||
|
return [gr.update()] * 10 + ["No image selected"]
|
||||||
|
|
||||||
|
metadata = read_image_metadata(selected_image_path)
|
||||||
|
if not metadata:
|
||||||
|
return [gr.update()] * 10 + ["No settings found in this image"]
|
||||||
|
|
||||||
|
# Return updates for each input element in order
|
||||||
|
updates = [
|
||||||
|
gr.update(value=metadata.get('image_prompt', '')),
|
||||||
|
gr.update(value=metadata.get('image_neg_prompt', '')),
|
||||||
|
gr.update(value=metadata.get('image_width', 1024)),
|
||||||
|
gr.update(value=metadata.get('image_height', 1024)),
|
||||||
|
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
|
||||||
|
gr.update(value=metadata.get('image_steps', 9)),
|
||||||
|
gr.update(value=metadata.get('image_seed', -1)),
|
||||||
|
gr.update(value=metadata.get('image_batch_size', 1)),
|
||||||
|
gr.update(value=metadata.get('image_batch_count', 1)),
|
||||||
|
gr.update(value=metadata.get('image_cfg_scale', 0.0)),
|
||||||
|
]
|
||||||
|
|
||||||
|
status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})"
|
||||||
|
return updates + [status]
|
||||||
|
|
||||||
|
|
||||||
|
def read_dropped_image_metadata(image_path):
|
||||||
|
"""Read metadata from a dropped/uploaded image."""
|
||||||
|
if not image_path:
|
||||||
|
return "Drop an image to view its generation settings."
|
||||||
|
|
||||||
|
metadata = read_image_metadata(image_path)
|
||||||
|
return format_metadata_for_display(metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def create_ui():
|
||||||
|
if shared.settings['image_model_menu'] != 'None':
|
||||||
|
shared.image_model_name = shared.settings['image_model_menu']
|
||||||
|
|
||||||
|
with gr.Tab("Image AI", elem_id="image-ai-tab"):
|
||||||
|
with gr.Tabs():
|
||||||
|
# TAB 1: GENERATE
|
||||||
|
with gr.TabItem("Generate"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=4, min_width=350):
|
||||||
|
shared.gradio['image_prompt'] = gr.Textbox(
|
||||||
|
label="Prompt",
|
||||||
|
placeholder="Describe your imagination...",
|
||||||
|
lines=3,
|
||||||
|
autofocus=True,
|
||||||
|
value=shared.settings['image_prompt']
|
||||||
|
)
|
||||||
|
shared.gradio['image_neg_prompt'] = gr.Textbox(
|
||||||
|
label="Negative Prompt",
|
||||||
|
placeholder="Low quality...",
|
||||||
|
lines=3,
|
||||||
|
value=shared.settings['image_neg_prompt']
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
|
||||||
|
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
|
||||||
|
gr.HTML("<hr style='border-top: 1px solid #444; margin: 20px 0;'>")
|
||||||
|
|
||||||
|
gr.Markdown("### Dimensions")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width")
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height")
|
||||||
|
shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['image_aspect_ratio'] = gr.Radio(
|
||||||
|
choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"],
|
||||||
|
value=shared.settings['image_aspect_ratio'],
|
||||||
|
label="Aspect Ratio",
|
||||||
|
interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
gr.Markdown("### Config")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps")
|
||||||
|
shared.gradio['image_cfg_scale'] = gr.Slider(
|
||||||
|
0.0, 10.0,
|
||||||
|
value=shared.settings['image_cfg_scale'],
|
||||||
|
step=0.1,
|
||||||
|
label="CFG Scale",
|
||||||
|
info="Z-Image Turbo: 0.0 | Qwen: 4.0"
|
||||||
|
)
|
||||||
|
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
|
||||||
|
shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
|
||||||
|
|
||||||
|
with gr.Column(scale=6, min_width=500):
|
||||||
|
with gr.Column(elem_classes=["viewport-container"]):
|
||||||
|
shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", object_fit="contain", preview=True, elem_id="image-output-gallery")
|
||||||
|
|
||||||
|
# TAB 2: GALLERY (with pagination)
|
||||||
|
with gr.TabItem("Gallery"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
# Pagination controls
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button")
|
||||||
|
shared.gradio['image_prev_page'] = gr.Button("◀ Prev", elem_classes="refresh-button")
|
||||||
|
shared.gradio['image_page_info'] = gr.Markdown(value=get_initial_page_info, elem_id="image-page-info")
|
||||||
|
shared.gradio['image_next_page'] = gr.Button("Next ▶", elem_classes="refresh-button")
|
||||||
|
shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80)
|
||||||
|
shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50)
|
||||||
|
|
||||||
|
# State for current page and selected image path
|
||||||
|
shared.gradio['image_current_page'] = gr.State(value=0)
|
||||||
|
shared.gradio['image_selected_path'] = gr.State(value="")
|
||||||
|
|
||||||
|
# Paginated gallery using gr.Gallery
|
||||||
|
shared.gradio['image_history_gallery'] = gr.Gallery(
|
||||||
|
value=lambda: get_paginated_images(0)[0],
|
||||||
|
label="Image History",
|
||||||
|
show_label=False,
|
||||||
|
columns=6,
|
||||||
|
object_fit="cover",
|
||||||
|
height="auto",
|
||||||
|
allow_preview=True,
|
||||||
|
elem_id="image-history-gallery"
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
gr.Markdown("### Selected Image")
|
||||||
|
shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings")
|
||||||
|
shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary")
|
||||||
|
shared.gradio['image_gallery_status'] = gr.Markdown("")
|
||||||
|
|
||||||
|
gr.Markdown("### Import Image")
|
||||||
|
shared.gradio['image_drop_upload'] = gr.Image(
|
||||||
|
label="Drop image here to view settings",
|
||||||
|
type="filepath",
|
||||||
|
height=150
|
||||||
|
)
|
||||||
|
|
||||||
|
# TAB 3: MODEL
|
||||||
|
with gr.TabItem("Model"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['image_model_menu'] = gr.Dropdown(
|
||||||
|
choices=utils.get_available_image_models(),
|
||||||
|
value=shared.settings['image_model_menu'],
|
||||||
|
label='Model',
|
||||||
|
elem_classes='slim-dropdown'
|
||||||
|
)
|
||||||
|
shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40)
|
||||||
|
shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button')
|
||||||
|
shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button')
|
||||||
|
|
||||||
|
gr.Markdown("## Settings")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_quant'] = gr.Dropdown(
|
||||||
|
label='Quantization',
|
||||||
|
choices=['none', 'bnb-8bit', 'bnb-4bit', 'quanto-8bit', 'quanto-4bit', 'quanto-2bit'],
|
||||||
|
value=shared.settings['image_quant'],
|
||||||
|
info='Quantization method for reduced VRAM usage. Quanto supports lower precisions (2-bit, 4-bit, 8-bit).'
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_dtype'] = gr.Dropdown(
|
||||||
|
choices=['bfloat16', 'float16'],
|
||||||
|
value=shared.settings['image_dtype'],
|
||||||
|
label='Data Type',
|
||||||
|
info='bfloat16 recommended for modern GPUs'
|
||||||
|
)
|
||||||
|
shared.gradio['image_attn_backend'] = gr.Dropdown(
|
||||||
|
choices=['sdpa', 'flash_attention_2', 'flash_attention_3'],
|
||||||
|
value=shared.settings['image_attn_backend'],
|
||||||
|
label='Attention Backend',
|
||||||
|
info='SDPA is default. Flash Attention requires compatible GPU.'
|
||||||
|
)
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_compile'] = gr.Checkbox(
|
||||||
|
value=shared.settings['image_compile'],
|
||||||
|
label='Compile Model',
|
||||||
|
info='Faster inference after first run. First run will be slow.'
|
||||||
|
)
|
||||||
|
shared.gradio['image_cpu_offload'] = gr.Checkbox(
|
||||||
|
value=shared.settings['image_cpu_offload'],
|
||||||
|
label='CPU Offload',
|
||||||
|
info='Enable for low VRAM GPUs. Slower but uses less memory.'
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['image_download_path'] = gr.Textbox(
|
||||||
|
label="Download model",
|
||||||
|
placeholder="Tongyi-MAI/Z-Image-Turbo",
|
||||||
|
info="Enter HuggingFace path. Use : for branch, e.g. user/model:main"
|
||||||
|
)
|
||||||
|
shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary')
|
||||||
|
shared.gradio['image_model_status'] = gr.Markdown(
|
||||||
|
value=f"Model: **{shared.settings['image_model_menu']}** (not loaded)" if shared.settings['image_model_menu'] != 'None' else "No model selected"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_event_handlers():
|
||||||
|
# Dimension controls
|
||||||
|
shared.gradio['image_aspect_ratio'].change(
|
||||||
|
apply_aspect_ratio,
|
||||||
|
gradio('image_aspect_ratio', 'image_width', 'image_height'),
|
||||||
|
gradio('image_width', 'image_height'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_width'].release(
|
||||||
|
update_height_from_width,
|
||||||
|
gradio('image_width', 'image_aspect_ratio'),
|
||||||
|
gradio('image_height'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_height'].release(
|
||||||
|
update_width_from_height,
|
||||||
|
gradio('image_height', 'image_aspect_ratio'),
|
||||||
|
gradio('image_width'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_swap_btn'].click(
|
||||||
|
swap_dimensions_and_update_ratio,
|
||||||
|
gradio('image_width', 'image_height', 'image_aspect_ratio'),
|
||||||
|
gradio('image_width', 'image_height', 'image_aspect_ratio'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generation
|
||||||
|
shared.gradio['image_generate_btn'].click(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||||
|
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||||
|
|
||||||
|
shared.gradio['image_prompt'].submit(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||||
|
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||||
|
|
||||||
|
shared.gradio['image_neg_prompt'].submit(
|
||||||
|
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
|
lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then(
|
||||||
|
generate, gradio('interface_state'), gradio('image_output_gallery'), show_progress=False).then(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn'))
|
||||||
|
|
||||||
|
# Stop button
|
||||||
|
shared.gradio['image_stop_btn'].click(
|
||||||
|
stop_everything_event, None, None, show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model management
|
||||||
|
shared.gradio['image_refresh_models'].click(
|
||||||
|
lambda: gr.update(choices=utils.get_available_image_models()),
|
||||||
|
None,
|
||||||
|
gradio('image_model_menu'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_load_model'].click(
|
||||||
|
load_image_model_wrapper,
|
||||||
|
gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'),
|
||||||
|
gradio('image_model_status'),
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_unload_model'].click(
|
||||||
|
unload_image_model_wrapper,
|
||||||
|
None,
|
||||||
|
gradio('image_model_status'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_download_btn'].click(
|
||||||
|
download_image_model_wrapper,
|
||||||
|
gradio('image_download_path'),
|
||||||
|
gradio('image_model_status', 'image_model_menu'),
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gallery pagination handlers
|
||||||
|
shared.gradio['image_refresh_history'].click(
|
||||||
|
refresh_gallery,
|
||||||
|
gradio('image_current_page'),
|
||||||
|
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_next_page'].click(
|
||||||
|
next_page,
|
||||||
|
gradio('image_current_page'),
|
||||||
|
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_prev_page'].click(
|
||||||
|
prev_page,
|
||||||
|
gradio('image_current_page'),
|
||||||
|
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_go_to_page'].click(
|
||||||
|
go_to_page,
|
||||||
|
gradio('image_page_input', 'image_current_page'),
|
||||||
|
gradio('image_history_gallery', 'image_current_page', 'image_page_info'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Image selection from gallery
|
||||||
|
shared.gradio['image_history_gallery'].select(
|
||||||
|
on_gallery_select,
|
||||||
|
gradio('image_current_page'),
|
||||||
|
gradio('image_selected_path', 'image_settings_display'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send to Generate
|
||||||
|
shared.gradio['image_send_to_generate'].click(
|
||||||
|
send_to_generate,
|
||||||
|
gradio('image_selected_path'),
|
||||||
|
gradio(
|
||||||
|
'image_prompt',
|
||||||
|
'image_neg_prompt',
|
||||||
|
'image_width',
|
||||||
|
'image_height',
|
||||||
|
'image_aspect_ratio',
|
||||||
|
'image_steps',
|
||||||
|
'image_seed',
|
||||||
|
'image_batch_size',
|
||||||
|
'image_batch_count',
|
||||||
|
'image_cfg_scale',
|
||||||
|
'image_gallery_status'
|
||||||
|
),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['image_drop_upload'].change(
|
||||||
|
read_dropped_image_metadata,
|
||||||
|
gradio('image_drop_upload'),
|
||||||
|
gradio('image_settings_display'),
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(state):
|
||||||
|
"""
|
||||||
|
Generate images using the loaded model.
|
||||||
|
Automatically adjusts parameters based on pipeline type.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules.torch_utils import clear_torch_cache
|
||||||
|
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_name = state['image_model_menu']
|
||||||
|
|
||||||
|
if not model_name or model_name == 'None':
|
||||||
|
logger.error("No image model selected. Go to the Model tab and select a model.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if shared.image_model is None:
|
||||||
|
result = load_image_model(
|
||||||
|
model_name,
|
||||||
|
dtype=state['image_dtype'],
|
||||||
|
attn_backend=state['image_attn_backend'],
|
||||||
|
cpu_offload=state['image_cpu_offload'],
|
||||||
|
compile_model=state['image_compile'],
|
||||||
|
quant_method=state['image_quant']
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
logger.error(f"Failed to load model `{model_name}`.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
shared.image_model_name = model_name
|
||||||
|
|
||||||
|
seed = state['image_seed']
|
||||||
|
if seed == -1:
|
||||||
|
seed = np.random.randint(0, 2**32 - 1)
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
if device is None:
|
||||||
|
device = "cpu"
|
||||||
|
generator = torch.Generator(device).manual_seed(int(seed))
|
||||||
|
|
||||||
|
all_images = []
|
||||||
|
|
||||||
|
# Get pipeline type for parameter adjustment
|
||||||
|
pipeline_type = getattr(shared, 'image_pipeline_type', None)
|
||||||
|
if pipeline_type is None:
|
||||||
|
pipeline_type = get_pipeline_type(shared.image_model)
|
||||||
|
|
||||||
|
# Process Prompt
|
||||||
|
prompt = state['image_prompt']
|
||||||
|
|
||||||
|
# Apply "Positive Magic" for Qwen models only
|
||||||
|
if pipeline_type == 'qwenimage':
|
||||||
|
magic_suffix = ", Ultra HD, 4K, cinematic composition"
|
||||||
|
# Avoid duplication if user already added it
|
||||||
|
if magic_suffix.strip(", ") not in prompt:
|
||||||
|
prompt += magic_suffix
|
||||||
|
|
||||||
|
# Reset stop flag at start
|
||||||
|
shared.stop_everything = False
|
||||||
|
|
||||||
|
# Callback to check for interruption during diffusion steps
|
||||||
|
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
|
||||||
|
if shared.stop_everything:
|
||||||
|
pipe._interrupt = True
|
||||||
|
|
||||||
|
return callback_kwargs
|
||||||
|
|
||||||
|
# Build generation kwargs
|
||||||
|
gen_kwargs = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": state['image_neg_prompt'],
|
||||||
|
"height": int(state['image_height']),
|
||||||
|
"width": int(state['image_width']),
|
||||||
|
"num_inference_steps": int(state['image_steps']),
|
||||||
|
"num_images_per_prompt": int(state['image_batch_size']),
|
||||||
|
"generator": generator,
|
||||||
|
"callback_on_step_end": interrupt_callback,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add pipeline-specific parameters for CFG
|
||||||
|
cfg_val = state.get('image_cfg_scale', 0.0)
|
||||||
|
|
||||||
|
if pipeline_type == 'qwenimage':
|
||||||
|
# Qwen-Image uses true_cfg_scale (typically 4.0)
|
||||||
|
gen_kwargs["true_cfg_scale"] = cfg_val
|
||||||
|
else:
|
||||||
|
# Z-Image and others use guidance_scale (typically 0.0 for Turbo)
|
||||||
|
gen_kwargs["guidance_scale"] = cfg_val
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
for i in range(int(state['image_batch_count'])):
|
||||||
|
if shared.stop_everything:
|
||||||
|
break
|
||||||
|
|
||||||
|
generator.manual_seed(int(seed + i))
|
||||||
|
batch_results = shared.image_model(**gen_kwargs).images
|
||||||
|
all_images.extend(batch_results)
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
save_generated_images(all_images, state, seed)
|
||||||
|
|
||||||
|
total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
|
||||||
|
total_steps = state["image_steps"] * int(state['image_batch_count'])
|
||||||
|
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
|
||||||
|
|
||||||
|
return all_images
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Image generation failed: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
|
||||||
|
if not model_name or model_name == 'None':
|
||||||
|
yield "No model selected"
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield f"Loading `{model_name}`..."
|
||||||
|
unload_image_model()
|
||||||
|
|
||||||
|
result = load_image_model(
|
||||||
|
model_name,
|
||||||
|
dtype=dtype,
|
||||||
|
attn_backend=attn_backend,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
compile_model=compile_model,
|
||||||
|
quant_method=quant_method
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
shared.image_model_name = model_name
|
||||||
|
yield f"✓ Loaded **{model_name}** (quantization: {quant_method})"
|
||||||
|
else:
|
||||||
|
yield f"✗ Failed to load `{model_name}`"
|
||||||
|
except Exception:
|
||||||
|
yield f"Error:\n```\n{traceback.format_exc()}\n```"
|
||||||
|
|
||||||
|
|
||||||
|
def unload_image_model_wrapper():
|
||||||
|
previous_name = shared.image_model_name
|
||||||
|
unload_image_model()
|
||||||
|
if previous_name != 'None':
|
||||||
|
return f"Model: **{previous_name}** (unloaded)"
|
||||||
|
return "No model loaded"
|
||||||
|
|
||||||
|
|
||||||
|
def download_image_model_wrapper(model_path):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
if not model_path:
|
||||||
|
yield "No model specified", gr.update()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_path = model_path.strip()
|
||||||
|
if model_path.startswith('https://huggingface.co/'):
|
||||||
|
model_path = model_path[len('https://huggingface.co/'):]
|
||||||
|
elif model_path.startswith('huggingface.co/'):
|
||||||
|
model_path = model_path[len('huggingface.co/'):]
|
||||||
|
|
||||||
|
if ':' in model_path:
|
||||||
|
model_id, branch = model_path.rsplit(':', 1)
|
||||||
|
else:
|
||||||
|
model_id, branch = model_path, 'main'
|
||||||
|
|
||||||
|
folder_name = model_id.replace('/', '_')
|
||||||
|
output_folder = Path(shared.args.image_model_dir) / folder_name
|
||||||
|
|
||||||
|
yield f"Downloading `{model_id}` (branch: {branch})...", gr.update()
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
revision=branch,
|
||||||
|
local_dir=output_folder,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_choices = utils.get_available_image_models()
|
||||||
|
yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name)
|
||||||
|
except Exception:
|
||||||
|
yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update()
|
||||||
|
|
@ -86,7 +86,7 @@ def check_model_loaded():
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
def resolve_model_path(model_name_or_path):
|
def resolve_model_path(model_name_or_path, image_model=False):
|
||||||
"""
|
"""
|
||||||
Resolves a model path, checking for a direct path
|
Resolves a model path, checking for a direct path
|
||||||
before the default models directory.
|
before the default models directory.
|
||||||
|
|
@ -95,6 +95,8 @@ def resolve_model_path(model_name_or_path):
|
||||||
path_candidate = Path(model_name_or_path)
|
path_candidate = Path(model_name_or_path)
|
||||||
if path_candidate.exists():
|
if path_candidate.exists():
|
||||||
return path_candidate
|
return path_candidate
|
||||||
|
elif image_model:
|
||||||
|
return Path(f'{shared.args.image_model_dir}/{model_name_or_path}')
|
||||||
else:
|
else:
|
||||||
return Path(f'{shared.args.model_dir}/{model_name_or_path}')
|
return Path(f'{shared.args.model_dir}/{model_name_or_path}')
|
||||||
|
|
||||||
|
|
@ -153,6 +155,24 @@ def get_available_models():
|
||||||
return filtered_gguf_files + model_dirs
|
return filtered_gguf_files + model_dirs
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_image_models():
|
||||||
|
model_dir = Path(shared.args.image_model_dir)
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Find valid model directories
|
||||||
|
model_dirs = []
|
||||||
|
for item in os.listdir(model_dir):
|
||||||
|
item_path = model_dir / item
|
||||||
|
if not item_path.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_dirs.append(item)
|
||||||
|
|
||||||
|
model_dirs = sorted(model_dirs, key=natural_keys)
|
||||||
|
|
||||||
|
return model_dirs
|
||||||
|
|
||||||
|
|
||||||
def get_available_ggufs():
|
def get_available_ggufs():
|
||||||
model_list = []
|
model_list = []
|
||||||
model_dir = Path(shared.args.model_dir)
|
model_dir = Path(shared.args.model_dir)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -34,6 +35,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -34,6 +35,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ huggingface-hub==0.36.0
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
markdown
|
markdown
|
||||||
numpy==2.2.*
|
numpy==2.2.*
|
||||||
|
optimum-quanto==0.2.7
|
||||||
pandas
|
pandas
|
||||||
peft==0.18.*
|
peft==0.18.*
|
||||||
Pillow>=9.5.0
|
Pillow>=9.5.0
|
||||||
|
|
@ -32,6 +33,9 @@ wandb
|
||||||
gradio==4.37.*
|
gradio==4.37.*
|
||||||
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
https://github.com/oobabooga/gradio/releases/download/custom-build/gradio_client-1.0.2+custom.1-py3-none-any.whl
|
||||||
|
|
||||||
|
# Diffusers
|
||||||
|
diffusers @ git+https://github.com/huggingface/diffusers.git@edf36f5128abf3e6ecf92b5145115514363c58e6
|
||||||
|
|
||||||
# API
|
# API
|
||||||
flask_cloudflared==0.0.14
|
flask_cloudflared==0.0.14
|
||||||
sse-starlette==1.6.5
|
sse-starlette==1.6.5
|
||||||
|
|
|
||||||
24
server.py
24
server.py
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
|
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
|
||||||
|
from modules.image_models import load_image_model
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.prompts import load_prompt
|
from modules.prompts import load_prompt
|
||||||
|
|
||||||
|
|
@ -50,6 +51,7 @@ from modules import (
|
||||||
ui_chat,
|
ui_chat,
|
||||||
ui_default,
|
ui_default,
|
||||||
ui_file_saving,
|
ui_file_saving,
|
||||||
|
ui_image_generation,
|
||||||
ui_model_menu,
|
ui_model_menu,
|
||||||
ui_notebook,
|
ui_notebook,
|
||||||
ui_parameters,
|
ui_parameters,
|
||||||
|
|
@ -163,6 +165,7 @@ def create_interface():
|
||||||
ui_chat.create_character_settings_ui() # Character tab
|
ui_chat.create_character_settings_ui() # Character tab
|
||||||
ui_model_menu.create_ui() # Model tab
|
ui_model_menu.create_ui() # Model tab
|
||||||
if not shared.args.portable:
|
if not shared.args.portable:
|
||||||
|
ui_image_generation.create_ui() # Image generation tab
|
||||||
training.create_ui() # Training tab
|
training.create_ui() # Training tab
|
||||||
ui_session.create_ui() # Session tab
|
ui_session.create_ui() # Session tab
|
||||||
|
|
||||||
|
|
@ -170,6 +173,8 @@ def create_interface():
|
||||||
ui_chat.create_event_handlers()
|
ui_chat.create_event_handlers()
|
||||||
ui_default.create_event_handlers()
|
ui_default.create_event_handlers()
|
||||||
ui_notebook.create_event_handlers()
|
ui_notebook.create_event_handlers()
|
||||||
|
if not shared.args.portable:
|
||||||
|
ui_image_generation.create_event_handlers()
|
||||||
|
|
||||||
# Other events
|
# Other events
|
||||||
ui_file_saving.create_event_handlers()
|
ui_file_saving.create_event_handlers()
|
||||||
|
|
@ -256,6 +261,9 @@ if __name__ == "__main__":
|
||||||
if new_settings:
|
if new_settings:
|
||||||
shared.settings.update(new_settings)
|
shared.settings.update(new_settings)
|
||||||
|
|
||||||
|
# Apply CLI overrides for image model settings (CLI flags take precedence over saved settings)
|
||||||
|
shared.apply_image_model_cli_overrides()
|
||||||
|
|
||||||
# Fallback settings for models
|
# Fallback settings for models
|
||||||
shared.model_config['.*'] = get_fallback_settings()
|
shared.model_config['.*'] = get_fallback_settings()
|
||||||
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
|
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
|
||||||
|
|
@ -313,6 +321,22 @@ if __name__ == "__main__":
|
||||||
if shared.args.lora:
|
if shared.args.lora:
|
||||||
add_lora_to_model(shared.args.lora)
|
add_lora_to_model(shared.args.lora)
|
||||||
|
|
||||||
|
# Load image model if specified via CLI
|
||||||
|
if shared.args.image_model:
|
||||||
|
logger.info(f"Loading image model: {shared.args.image_model}")
|
||||||
|
result = load_image_model(
|
||||||
|
shared.args.image_model,
|
||||||
|
dtype=shared.settings.get('image_dtype', 'bfloat16'),
|
||||||
|
attn_backend=shared.settings.get('image_attn_backend', 'sdpa'),
|
||||||
|
cpu_offload=shared.settings.get('image_cpu_offload', False),
|
||||||
|
compile_model=shared.settings.get('image_compile', False),
|
||||||
|
quant_method=shared.settings.get('image_quant', 'none')
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
shared.image_model_name = shared.args.image_model
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to load image model: {shared.args.image_model}")
|
||||||
|
|
||||||
shared.generation_lock = Lock()
|
shared.generation_lock = Lock()
|
||||||
|
|
||||||
if shared.args.idle_timeout > 0:
|
if shared.args.idle_timeout > 0:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue