Compare commits

...

18 commits

Author SHA1 Message Date
oobabooga 5848c7884d Increase the height of the image output gallery 2025-12-05 10:24:51 -08:00
oobabooga c11c14590a Image: Better LLM variation default prompt 2025-12-05 08:08:11 -08:00
oobabooga 0dd468245c Image: Add back the gallery cache (for performance) 2025-12-05 07:11:38 -08:00
oobabooga b63d57158d Image: Add TGW as a prefix to output images 2025-12-05 05:59:54 -08:00
oobabooga afa29b9554 Image: Several fixes 2025-12-05 05:58:57 -08:00
oobabooga 8eac99599a Image: Better LLM variation default prompt 2025-12-04 19:58:06 -08:00
oobabooga b4f06a50b0 fix: Pass bos_token and eos_token from metadata to jinja2
Fixes loading Seed-Instruct-36B
2025-12-04 19:11:31 -08:00
oobabooga 15c6e43597 Image: Add a revised_prompt field to API results for OpenAI compatibility 2025-12-04 17:41:09 -08:00
oobabooga 56f2a9512f Revert "Image: Add the LLM-generated prompt to the API result"
This reverts commit c7ad28a4cd.
2025-12-04 17:34:27 -08:00
oobabooga 3ef428efaa Image: Remove llm_variations from the API 2025-12-04 17:34:17 -08:00
oobabooga c7ad28a4cd Image: Add the LLM-generated prompt to the API result 2025-12-04 17:22:08 -08:00
oobabooga b451bac082 Image: Improve a log message 2025-12-04 16:33:46 -08:00
oobabooga 47a0fcd614 Image: PNG metadata improvements 2025-12-04 16:25:48 -08:00
oobabooga ac31a7c008 Image: Organize the UI 2025-12-04 15:45:04 -08:00
oobabooga a90739f498 Image: Better LLM variation default prompt 2025-12-04 10:50:40 -08:00
oobabooga ffef3c7b1d Image: Make the LLM Variations prompt configurable 2025-12-04 10:44:35 -08:00
oobabooga 5763947c37 Image: Simplify the API code, add the llm_variations option 2025-12-04 10:23:00 -08:00
oobabooga 2793153717 Image: Add LLM-generated prompt variations 2025-12-04 08:10:24 -08:00
11 changed files with 203 additions and 189 deletions

View file

@ -28,8 +28,7 @@ A Gradio web UI for Large Language Models.
- 100% offline and private, with zero telemetry, external resources, or remote update requests.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
Image generation: A dedicated tab for diffusers models like Z-Image-Turbo and Qwen-Image. Features 4-bit/8-bit quantization and a persistent gallery with metadata (tutorial).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo** and **Qwen-Image**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.

View file

@ -1692,8 +1692,8 @@ button#swap-height-width {
}
#image-output-gallery, #image-output-gallery > :nth-child(2) {
height: calc(100vh - 83px);
max-height: calc(100vh - 83px);
height: calc(100vh - 66px);
max-height: calc(100vh - 66px);
}
#image-history-gallery, #image-history-gallery > :nth-child(2) {
@ -1791,3 +1791,9 @@ button#swap-height-width {
.dark #image-progress .image-ai-progress-text {
color: #888;
}
#llm-prompt-variations {
position: absolute;
top: 0;
left: calc(100% - 174px);
}

View file

@ -4,120 +4,61 @@ OpenAI-compatible image generation using local diffusion models.
import base64
import io
import json
import os
import time
from datetime import datetime
import numpy as np
from extensions.openai.errors import ServiceUnavailableError
from modules import shared
from modules.logging_colors import logger
from PIL.PngImagePlugin import PngInfo
def generations(prompt: str, size: str, response_format: str, n: int,
negative_prompt: str = "", steps: int = 9, seed: int = -1,
cfg_scale: float = 0.0, batch_count: int = 1):
def generations(request):
"""
Generate images using the loaded diffusion model.
Args:
prompt: Text description of the desired image
size: Image dimensions as "WIDTHxHEIGHT"
response_format: 'url' or 'b64_json'
n: Number of images per batch
negative_prompt: What to avoid in the image
steps: Number of inference steps
seed: Random seed (-1 for random)
cfg_scale: Classifier-free guidance scale
batch_count: Number of sequential batches
Returns:
dict with 'created' timestamp and 'data' list of images
Returns dict with 'created' timestamp and 'data' list of images.
"""
import torch
from modules.image_models import get_pipeline_type
from modules.torch_utils import clear_torch_cache, get_device
from modules.ui_image_generation import generate
if shared.image_model is None:
raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
clear_torch_cache()
width, height = request.get_width_height()
# Parse dimensions
try:
width, height = [int(x) for x in size.split('x')]
except (ValueError, IndexError):
width, height = 1024, 1024
# Build state dict: GenerationOptions fields + image-specific keys
state = request.model_dump()
state.update({
'image_model_menu': shared.image_model_name,
'image_prompt': request.prompt,
'image_neg_prompt': request.negative_prompt,
'image_width': width,
'image_height': height,
'image_steps': request.steps,
'image_seed': request.image_seed,
'image_batch_size': request.batch_size,
'image_batch_count': request.batch_count,
'image_cfg_scale': request.cfg_scale,
'image_llm_variations': False,
})
# Handle seed
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
# Exhaust generator, keep final result
images = []
for images, _ in generate(state, save_images=False):
pass
device = get_device() or "cpu"
generator = torch.Generator(device).manual_seed(int(seed))
# Get pipeline type for CFG parameter name
pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": steps,
"num_images_per_prompt": n,
"generator": generator,
}
# Pipeline-specific CFG parameter
if pipeline_type == 'qwenimage':
gen_kwargs["true_cfg_scale"] = cfg_scale
else:
gen_kwargs["guidance_scale"] = cfg_scale
# Generate
all_images = []
t0 = time.time()
shared.stop_everything = False
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
if shared.stop_everything:
pipe._interrupt = True
return callback_kwargs
gen_kwargs["callback_on_step_end"] = interrupt_callback
for i in range(batch_count):
if shared.stop_everything:
break
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
t1 = time.time()
total_images = len(all_images)
total_steps = steps * batch_count
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
# Save images
_save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
if not images:
raise ServiceUnavailableError("Image generation failed or produced no images.")
# Build response
resp = {
'created': int(time.time()),
'data': []
}
for img in all_images:
resp = {'created': int(time.time()), 'data': []}
for img in images:
b64 = _image_to_base64(img)
if response_format == 'b64_json':
resp['data'].append({'b64_json': b64})
image_obj = {'revised_prompt': request.prompt}
if request.response_format == 'b64_json':
image_obj['b64_json'] = b64
else:
resp['data'].append({'url': f'data:image/png;base64,{b64}'})
image_obj['url'] = f'data:image/png;base64,{b64}'
resp['data'].append(image_obj)
return resp
@ -126,29 +67,3 @@ def _image_to_base64(image) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
"""Save images with metadata."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder, exist_ok=True)
metadata = {
'image_prompt': prompt,
'image_neg_prompt': negative_prompt,
'image_width': width,
'image_height': height,
'image_steps': steps,
'image_seed': seed,
'image_cfg_scale': cfg_scale,
'model': getattr(shared, 'image_model_name', 'unknown'),
}
for idx, img in enumerate(images):
ts = datetime.now().strftime("%H-%M-%S")
filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
png_info = PngInfo()
png_info.add_text("image_gen_settings", json.dumps(metadata))
img.save(filepath, pnginfo=png_info)
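For reference, a minimal client-side sketch of calling this rewritten endpoint. The route path `/v1/images/generations` and port 5000 are assumptions based on the OpenAI-compatible API (they are not shown in this diff); the payload fields mirror the `ImageGenerationRequest` model further down, and the response handling reflects the new `revised_prompt` plus `b64_json`/`url` fields built above.

```python
import base64

import requests

# Assumed endpoint and port; adjust to your API configuration.
API_URL = "http://127.0.0.1:5000/v1/images/generations"

payload = {
    "prompt": "a lighthouse at dusk, dramatic sky",
    "negative_prompt": "",
    "size": "1024x1024",
    "steps": 9,
    "cfg_scale": 0.0,
    "image_seed": -1,        # -1 = random
    "n": 1,                  # OpenAI-style alias, resolved to batch_size server-side
    "batch_count": 1,
    "response_format": "b64_json",
}

r = requests.post(API_URL, json=payload, timeout=600)
r.raise_for_status()

for i, item in enumerate(r.json()["data"]):
    print(item["revised_prompt"])                    # always present now
    with open(f"output_{i}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))  # because response_format='b64_json'
```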

View file

@ -7,23 +7,24 @@ import traceback
from collections import deque
from threading import Thread
import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import uvicorn
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
from .typing import (
ChatCompletionRequest,
@ -232,20 +233,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages
async with image_generation_semaphore:
width, height = request_data.get_width_height()
response = await asyncio.to_thread(
OAIimages.generations,
prompt=request_data.prompt,
size=f"{width}x{height}",
response_format=request_data.response_format,
n=request_data.batch_size, # <-- use resolved batch_size
negative_prompt=request_data.negative_prompt,
steps=request_data.steps,
seed=request_data.seed,
cfg_scale=request_data.cfg_scale,
batch_count=request_data.batch_count,
)
response = await asyncio.to_thread(OAIimages.generations, request_data)
return JSONResponse(response)
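A self-contained sketch of the concurrency pattern this handler relies on: the semaphore is only used (not created) in this hunk, so the limit of 1 is an assumption, and `blocking_generate` is a hypothetical stand-in for `OAIimages.generations(request_data)`.

```python
import asyncio
import time

image_generation_semaphore = asyncio.Semaphore(1)  # assumed limit


def blocking_generate(request_data):
    """Hypothetical stand-in for the blocking diffusers call."""
    time.sleep(1)
    return {"created": int(time.time()), "data": []}


async def handle_image_generation(request_data):
    async with image_generation_semaphore:
        # Run the blocking call in a worker thread so the event loop
        # keeps serving other API requests in the meantime.
        return await asyncio.to_thread(blocking_generate, request_data)


# asyncio.run(handle_image_generation({"prompt": "a lighthouse at dusk"}))
```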

View file

@ -265,16 +265,13 @@ class LoadLorasRequest(BaseModel):
class ImageGenerationRequest(BaseModel):
"""OpenAI-compatible image generation request with extended parameters."""
# Required
"""Image-specific parameters for generation."""
prompt: str
# Generation parameters
negative_prompt: str = ""
size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
steps: int = Field(default=9, ge=1)
cfg_scale: float = Field(default=0.0, ge=0.0)
seed: int = Field(default=-1, description="-1 for random")
image_seed: int = Field(default=-1, description="-1 for random")
batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
@ -286,7 +283,6 @@ class ImageGenerationRequest(BaseModel):
@model_validator(mode='after')
def resolve_batch_size(self):
"""Use batch_size if provided, otherwise fall back to n."""
if self.batch_size is None:
self.batch_size = self.n
return self
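The `n`/`batch_size` aliasing above can be exercised in isolation; a trimmed-down sketch with only the fields involved (class name is hypothetical):

```python
from pydantic import BaseModel, Field, model_validator


class ImageRequestSketch(BaseModel):
    prompt: str
    batch_size: int | None = Field(default=None, ge=1)
    n: int = Field(default=1, ge=1)  # OpenAI-compatible alias

    @model_validator(mode='after')
    def resolve_batch_size(self):
        if self.batch_size is None:
            self.batch_size = self.n
        return self


assert ImageRequestSketch(prompt="x", n=4).batch_size == 4                 # n used as fallback
assert ImageRequestSketch(prompt="x", n=4, batch_size=2).batch_size == 2   # explicit value wins
```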

View file

@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
add_generation_prompt=False,
enable_thinking=state['enable_thinking'],
reasoning_effort=state['reasoning_effort'],
thinking_budget=-1 if state.get('enable_thinking', True) else 0
thinking_budget=-1 if state.get('enable_thinking', True) else 0,
bos_token=shared.bos_token,
eos_token=shared.eos_token,
)
chat_renderer = partial(
@ -475,7 +477,7 @@ def get_stopping_strings(state):
if state['mode'] in ['instruct', 'chat-instruct']:
template = jinja_env.from_string(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False)
renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
renderers.append(renderer)
if state['mode'] in ['chat']:
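To see why the extra `bos_token`/`eos_token` keyword arguments matter (the fix referenced in commit b4f06a50b0 for Seed-Instruct-36B), here is a minimal sketch with a hypothetical chat template that references those names as Jinja2 variables rather than hard-coded strings; without passing them, they would render as empty.

```python
from jinja2 import Environment

# Hypothetical template in the style of models whose chat template uses
# bos_token/eos_token as variables (not taken from any real model).
template_str = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
)

template = Environment().from_string(template_str)
prompt = template.render(
    messages=[{"role": "user", "content": "Hello"}],
    bos_token="<s>",    # now sourced from shared.bos_token
    eos_token="</s>",   # now sourced from shared.eos_token
)
print(prompt)  # <s>user: Hello</s>
```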

View file

@ -141,16 +141,24 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
if not cpu_offload:
pipe.to(get_device())
# Set attention backend (if supported by the pipeline)
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
if attn_backend == 'flash_attention_2':
pipe.transformer.set_attention_backend("flash")
# sdpa is the default, no action needed
modules = ["transformer", "unet"]
# Set attention backend
if attn_backend == 'flash_attention_2':
for name in modules:
mod = getattr(pipe, name, None)
if hasattr(mod, "set_attention_backend"):
mod.set_attention_backend("flash")
break
# Compile model
if compile_model:
if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
logger.info("Compiling model (first run will be slow)...")
pipe.transformer.compile()
for name in modules:
mod = getattr(pipe, name, None)
if hasattr(mod, "compile"):
logger.info("Compiling model (first run will be slow)...")
mod.compile()
break
if cpu_offload:
pipe.enable_model_cpu_offload()
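The loops above probe for whichever denoiser the pipeline exposes, since `diffusers` pipelines name it either `transformer` (DiT-style models such as Z-Image-Turbo) or `unet` (UNet-based pipelines). The same probe-and-apply pattern as a generic sketch (helper name is hypothetical):

```python
def apply_to_denoiser(pipe, method_name, *args):
    """Call method_name on the first of pipe.transformer / pipe.unet that supports it."""
    for name in ("transformer", "unet"):
        mod = getattr(pipe, name, None)
        if hasattr(mod, method_name):
            getattr(mod, method_name)(*args)
            return name
    return None


# apply_to_denoiser(pipe, "set_attention_backend", "flash")
# apply_to_denoiser(pipe, "compile")
```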

View file

@ -89,8 +89,9 @@ def get_model_metadata(model):
else:
bos_token = ""
template = template.replace('eos_token', "'{}'".format(eos_token))
template = template.replace('bos_token', "'{}'".format(bos_token))
shared.bos_token = bos_token
shared.eos_token = eos_token
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
@ -160,13 +161,16 @@ def get_model_metadata(model):
# 4. If a template was found from any source, process it
if template:
shared.bos_token = '<s>'
shared.eos_token = '</s>'
for k in ['eos_token', 'bos_token']:
if k in metadata:
value = metadata[k]
if isinstance(value, dict):
value = value['content']
template = template.replace(k, "'{}'".format(value))
setattr(shared, k, value)
template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
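The `isinstance(value, dict)` branch above exists because tokenizer metadata may store special tokens either as plain strings or as AddedToken-style dicts; a tiny sketch of that normalization (function name is hypothetical):

```python
def special_token_text(value):
    """Return the token string whether metadata stores it as str or dict."""
    if isinstance(value, dict):
        return value['content']
    return value


assert special_token_text('</s>') == '</s>'
assert special_token_text({'content': '<|im_end|>'}) == '<|im_end|>'
```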

View file

@ -19,6 +19,8 @@ is_seq2seq = False
is_multimodal = False
model_dirty_from_training = False
lora_names = []
bos_token = '<s>'
eos_token = '</s>'
# Image model variables
image_model = None
@ -319,6 +321,8 @@ settings = {
'image_seed': -1,
'image_batch_size': 1,
'image_batch_count': 1,
'image_llm_variations': False,
'image_llm_variations_prompt': 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.',
'image_model_menu': 'None',
'image_dtype': 'bfloat16',
'image_attn_backend': 'sdpa',

View file

@ -293,6 +293,8 @@ def list_interface_input_elements():
'image_seed',
'image_batch_size',
'image_batch_count',
'image_llm_variations',
'image_llm_variations_prompt',
'image_model_menu',
'image_dtype',
'image_attn_backend',
@ -547,6 +549,8 @@ def setup_auto_save():
'image_seed',
'image_batch_size',
'image_batch_count',
'image_llm_variations',
'image_llm_variations_prompt',
'image_model_menu',
'image_dtype',
'image_attn_backend',

View file

@ -18,7 +18,7 @@ from modules.image_models import (
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.utils import gradio
from modules.utils import check_model_loaded, gradio
ASPECT_RATIOS = {
"1:1 Square": (1, 1),
@ -40,8 +40,6 @@ METADATA_SETTINGS_KEYS = [
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_cfg_scale',
]
@ -136,6 +134,9 @@ def build_generation_metadata(state, actual_seed):
def save_generated_images(images, state, actual_seed):
"""Save images with generation metadata embedded in PNG."""
if shared.args.multi_user:
return
date_str = datetime.now().strftime("%Y-%m-%d")
folder_path = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder_path, exist_ok=True)
@ -145,7 +146,7 @@ def save_generated_images(images, state, actual_seed):
for idx, img in enumerate(images):
timestamp = datetime.now().strftime("%H-%M-%S")
filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
filename = f"TGW_{timestamp}_{actual_seed:010d}_{idx:03d}.png"
filepath = os.path.join(folder_path, filename)
# Create PNG metadata
@ -159,9 +160,14 @@ def save_generated_images(images, state, actual_seed):
def read_image_metadata(image_path):
"""Read generation metadata from PNG file."""
try:
with open_image_safely(image_path) as img:
img = open_image_safely(image_path)
if img is None:
return None
try:
if hasattr(img, 'text') and 'image_gen_settings' in img.text:
return json.loads(img.text['image_gen_settings'])
finally:
img.close()
except Exception as e:
logger.debug(f"Could not read metadata from {image_path}: {e}")
return None
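For reference, a small round-trip sketch of the PNG text-chunk format this reader (and the writer shown earlier in this diff) relies on; the settings dict and file name are illustrative.

```python
import json

from PIL import Image
from PIL.PngImagePlugin import PngInfo

# Write: embed generation settings as a PNG text chunk
settings = {"image_prompt": "a lighthouse at dusk", "image_steps": 9, "image_seed": 42}
png_info = PngInfo()
png_info.add_text("image_gen_settings", json.dumps(settings))

img = Image.new("RGB", (64, 64))
img.save("example.png", pnginfo=png_info)

# Read: the same key comes back through the .text mapping
with Image.open("example.png") as loaded:
    recovered = json.loads(loaded.text["image_gen_settings"])
print(recovered["image_prompt"])
```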
@ -184,8 +190,6 @@ def format_metadata_for_display(metadata):
('image_steps', 'Steps'),
('image_cfg_scale', 'CFG Scale'),
('image_seed', 'Seed'),
('image_batch_size', 'Batch Size'),
('image_batch_count', 'Batch Count'),
('model', 'Model'),
('generated_at', 'Generated At'),
]
@ -293,7 +297,6 @@ def on_gallery_select(evt: gr.SelectData, current_page):
if not _image_cache:
get_all_history_images()
# Get the current page's images to find the actual file path
all_images = _image_cache
total_images = len(all_images)
@ -314,11 +317,11 @@ def on_gallery_select(evt: gr.SelectData, current_page):
def send_to_generate(selected_image_path):
"""Load settings from selected image and return updates for all Generate tab inputs."""
if not selected_image_path or not os.path.exists(selected_image_path):
return [gr.update()] * 10 + ["No image selected"]
return [gr.update()] * 8 + ["No image selected"]
metadata = read_image_metadata(selected_image_path)
if not metadata:
return [gr.update()] * 10 + ["No settings found in this image"]
return [gr.update()] * 8 + ["No settings found in this image"]
# Return updates for each input element in order
updates = [
@ -329,8 +332,6 @@ def send_to_generate(selected_image_path):
gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
gr.update(value=metadata.get('image_steps', 9)),
gr.update(value=metadata.get('image_seed', -1)),
gr.update(value=metadata.get('image_batch_size', 1)),
gr.update(value=metadata.get('image_batch_count', 1)),
gr.update(value=metadata.get('image_cfg_scale', 0.0)),
]
@ -370,6 +371,19 @@ def create_ui():
lines=3,
value=shared.settings['image_neg_prompt']
)
shared.gradio['image_llm_variations'] = gr.Checkbox(
value=shared.settings['image_llm_variations'],
label='LLM Prompt Variations',
elem_id="llm-prompt-variations",
)
shared.gradio['image_llm_variations_prompt'] = gr.Textbox(
value=shared.settings['image_llm_variations_prompt'],
label='Variation Prompt',
lines=3,
placeholder='Instructions for generating prompt variations...',
visible=shared.settings['image_llm_variations'],
info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
)
shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)
@ -406,6 +420,7 @@ def create_ui():
info="Z-Image Turbo: 0.0 | Qwen: 4.0"
)
shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")
with gr.Column():
shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")
@ -647,8 +662,6 @@ def create_event_handlers():
'image_aspect_ratio',
'image_steps',
'image_seed',
'image_batch_size',
'image_batch_count',
'image_cfg_scale',
'image_gallery_status'
),
@ -663,6 +676,68 @@ def create_event_handlers():
show_progress=False
)
# LLM Variations visibility toggle
shared.gradio['image_llm_variations'].change(
lambda x: gr.update(visible=x),
gradio('image_llm_variations'),
gradio('image_llm_variations_prompt'),
show_progress=False
)
def generate_prompt_variation(state):
"""Generate a creative variation of the image prompt using the LLM."""
from modules.chat import generate_chat_prompt
from modules.text_generation import generate_reply
prompt = state['image_prompt']
# Check if LLM is loaded
model_loaded, _ = check_model_loaded()
if not model_loaded:
logger.warning("No LLM loaded for prompt variation. Using original prompt.")
return prompt
# Get the custom variation prompt or use default
variation_instruction = state.get('image_llm_variations_prompt', '')
if not variation_instruction:
variation_instruction = 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.'
augmented_message = f"{prompt}\n\n=====\n\n{variation_instruction}"
# Use minimal state for generation
var_state = state.copy()
var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
var_state['auto_max_new_tokens'] = True
var_state['enable_thinking'] = False
var_state['reasoning_effort'] = 'low'
var_state['start_with'] = ""
formatted_prompt = generate_chat_prompt(augmented_message, var_state)
variation = ""
for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True):
variation = reply
# Strip thinking blocks if present
if "</think>" in variation:
variation = variation.rsplit("</think>", 1)[1]
elif "<|start|>assistant<|channel|>final<|message|>" in variation:
variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
elif "</seed:think>" in variation:
variation = variation.rsplit("</seed:think>", 1)[1]
variation = variation.strip()
if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'):
variation = variation[1:-1]
if variation:
logger.info("Prompt variation:")
print(variation)
return variation
return prompt
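The reply cleanup above (dropping a leading thinking block, then wrapping quotes) can be read as one small helper; a sketch of the same steps factored out for clarity (helper name is hypothetical):

```python
THINKING_MARKERS = (
    "</think>",
    "<|start|>assistant<|channel|>final<|message|>",
    "</seed:think>",
)


def clean_variation(reply: str) -> str:
    """Strip a leading thinking block and surrounding quotes from an LLM reply."""
    for marker in THINKING_MARKERS:
        if marker in reply:
            reply = reply.rsplit(marker, 1)[1]
            break
    reply = reply.strip()
    if len(reply) >= 2 and reply.startswith('"') and reply.endswith('"'):
        reply = reply[1:-1]
    return reply
```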
def progress_bar_html(progress=0, text=""):
"""Generate HTML for progress bar. Empty div when progress <= 0."""
@ -671,13 +746,13 @@ def progress_bar_html(progress=0, text=""):
return f'''<div class="image-ai-progress-wrapper">
<div class="image-ai-progress-track">
<div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
<div class="image-ai-progress-fill" style="width: {progress * 100:.1f}%;"></div>
</div>
<div class="image-ai-progress-text">{text}</div>
</div>'''
def generate(state):
def generate(state, save_images=True):
"""
Generate images using the loaded model.
Automatically adjusts parameters based on pipeline type.
@ -720,7 +795,7 @@ def generate(state):
device = get_device()
if device is None:
device = "cpu"
generator = torch.Generator(device).manual_seed(int(seed))
generator = torch.Generator(device)
all_images = []
@ -729,14 +804,8 @@ def generate(state):
if pipeline_type is None:
pipeline_type = get_pipeline_type(shared.image_model)
# Process Prompt
prompt = state['image_prompt']
if pipeline_type == 'qwenimage':
magic_suffix = ", Ultra HD, 4K, cinematic composition"
if magic_suffix.strip(", ") not in prompt:
prompt += magic_suffix
shared.stop_everything = False
batch_count = int(state['image_batch_count'])
@ -777,13 +846,25 @@ def generate(state):
generator.manual_seed(int(seed + batch_idx))
# Generate prompt variation if enabled
if state['image_llm_variations']:
gen_kwargs["prompt"] = generate_prompt_variation(state)
# Run generation in thread so we can yield progress
result_holder = []
error_holder = []
def run_batch():
try:
# Apply magic suffix only at generation time for qwenimage
clean_prompt = gen_kwargs["prompt"]
if pipeline_type == 'qwenimage':
magic_suffix = ", Ultra HD, 4K, cinematic composition"
if magic_suffix.strip(", ") not in clean_prompt:
gen_kwargs["prompt"] = clean_prompt + magic_suffix
result_holder.extend(shared.image_model(**gen_kwargs).images)
gen_kwargs["prompt"] = clean_prompt # restore
except Exception as e:
error_holder.append(e)
@ -806,11 +887,18 @@ def generate(state):
if error_holder:
raise error_holder[0]
# Save this batch's images with the actual prompt and seed used
if save_images:
batch_seed = seed + batch_idx
original_prompt = state['image_prompt']
state['image_prompt'] = gen_kwargs["prompt"]
save_generated_images(result_holder, state, batch_seed)
state['image_prompt'] = original_prompt
all_images.extend(result_holder)
yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
t1 = time.time()
save_generated_images(all_images, state, seed)
total_images = batch_count * int(state['image_batch_size'])
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
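Condensing the control flow of the final hunks: the shared generator is reseeded with `seed + batch_idx` for each sequential batch, and each batch is saved as soon as it finishes (with the prompt actually used), so an interrupted run keeps its completed images. A stripped-down sketch with the pipeline call, prompt variation, and progress reporting elided (function name and `save_fn` signature are illustrative):

```python
import torch


def run_batches(pipe, gen_kwargs, seed, batch_count, save_fn):
    """Sketch of the per-batch seeding and incremental saving shown above."""
    generator = torch.Generator("cpu")  # device selection elided
    gen_kwargs["generator"] = generator
    all_images = []
    for batch_idx in range(batch_count):
        generator.manual_seed(int(seed + batch_idx))   # reproducible per-batch seed
        images = pipe(**gen_kwargs).images
        save_fn(images, seed + batch_idx)              # save immediately, not at the end
        all_images.extend(images)
    return all_images
```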