mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00

Compare commits: 7fb9f19bd8 ... 5848c7884d (18 commits)

Commits in this range (SHA1):
5848c7884d
c11c14590a
0dd468245c
b63d57158d
afa29b9554
8eac99599a
b4f06a50b0
15c6e43597
56f2a9512f
3ef428efaa
c7ad28a4cd
b451bac082
47a0fcd614
ac31a7c008
a90739f498
ffef3c7b1d
5763947c37
2793153717
@@ -28,8 +28,7 @@ A Gradio web UI for Large Language Models.

- 100% offline and private, with zero telemetry, external resources, or remote update requests.
- **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents.
- **Vision (multimodal models)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)).
Image generation: A dedicated tab for diffusers models like Z-Image-Turbo and Qwen-Image. Features 4-bit/8-bit quantization and a persistent gallery with metadata (tutorial).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo** and **Qwen-Image**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)).
- **Web search**: Optionally search the internet with LLM-generated queries to add context to the conversation.
- Aesthetic UI with dark and light themes.
- Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions.
css/main.css (10 lines changed)

@@ -1692,8 +1692,8 @@ button#swap-height-width {
}

#image-output-gallery, #image-output-gallery > :nth-child(2) {
    height: calc(100vh - 83px);
    max-height: calc(100vh - 83px);
    height: calc(100vh - 66px);
    max-height: calc(100vh - 66px);
}

#image-history-gallery, #image-history-gallery > :nth-child(2) {

@@ -1791,3 +1791,9 @@ button#swap-height-width {
.dark #image-progress .image-ai-progress-text {
    color: #888;
}

#llm-prompt-variations {
    position: absolute;
    top: 0;
    left: calc(100% - 174px);
}
@@ -4,120 +4,61 @@ OpenAI-compatible image generation using local diffusion models.
import base64
import io
import json
import os
import time
from datetime import datetime

import numpy as np
from extensions.openai.errors import ServiceUnavailableError
from modules import shared
from modules.logging_colors import logger
from PIL.PngImagePlugin import PngInfo


def generations(prompt: str, size: str, response_format: str, n: int,
                negative_prompt: str = "", steps: int = 9, seed: int = -1,
                cfg_scale: float = 0.0, batch_count: int = 1):
def generations(request):
    """
    Generate images using the loaded diffusion model.

    Args:
        prompt: Text description of the desired image
        size: Image dimensions as "WIDTHxHEIGHT"
        response_format: 'url' or 'b64_json'
        n: Number of images per batch
        negative_prompt: What to avoid in the image
        steps: Number of inference steps
        seed: Random seed (-1 for random)
        cfg_scale: Classifier-free guidance scale
        batch_count: Number of sequential batches

    Returns:
        dict with 'created' timestamp and 'data' list of images
    Returns dict with 'created' timestamp and 'data' list of images.
    """
    import torch
    from modules.image_models import get_pipeline_type
    from modules.torch_utils import clear_torch_cache, get_device
    from modules.ui_image_generation import generate

    if shared.image_model is None:
        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")

    clear_torch_cache()
    width, height = request.get_width_height()

    # Parse dimensions
    try:
        width, height = [int(x) for x in size.split('x')]
    except (ValueError, IndexError):
        width, height = 1024, 1024
    # Build state dict: GenerationOptions fields + image-specific keys
    state = request.model_dump()
    state.update({
        'image_model_menu': shared.image_model_name,
        'image_prompt': request.prompt,
        'image_neg_prompt': request.negative_prompt,
        'image_width': width,
        'image_height': height,
        'image_steps': request.steps,
        'image_seed': request.image_seed,
        'image_batch_size': request.batch_size,
        'image_batch_count': request.batch_count,
        'image_cfg_scale': request.cfg_scale,
        'image_llm_variations': False,
    })

    # Handle seed
    if seed == -1:
        seed = np.random.randint(0, 2**32 - 1)
    # Exhaust generator, keep final result
    images = []
    for images, _ in generate(state, save_images=False):
        pass

    device = get_device() or "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))

    # Get pipeline type for CFG parameter name
    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)

    # Build generation kwargs
    gen_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "num_images_per_prompt": n,
        "generator": generator,
    }

    # Pipeline-specific CFG parameter
    if pipeline_type == 'qwenimage':
        gen_kwargs["true_cfg_scale"] = cfg_scale
    else:
        gen_kwargs["guidance_scale"] = cfg_scale

    # Generate
    all_images = []
    t0 = time.time()

    shared.stop_everything = False

    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
        if shared.stop_everything:
            pipe._interrupt = True
        return callback_kwargs

    gen_kwargs["callback_on_step_end"] = interrupt_callback

    for i in range(batch_count):
        if shared.stop_everything:
            break
        generator.manual_seed(int(seed + i))
        batch_results = shared.image_model(**gen_kwargs).images
        all_images.extend(batch_results)

    t1 = time.time()
    total_images = len(all_images)
    total_steps = steps * batch_count
    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')

    # Save images
    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
    if not images:
        raise ServiceUnavailableError("Image generation failed or produced no images.")

    # Build response
    resp = {
        'created': int(time.time()),
        'data': []
    }

    for img in all_images:
    resp = {'created': int(time.time()), 'data': []}
    for img in images:
        b64 = _image_to_base64(img)
        if response_format == 'b64_json':
            resp['data'].append({'b64_json': b64})

        image_obj = {'revised_prompt': request.prompt}

        if request.response_format == 'b64_json':
            image_obj['b64_json'] = b64
        else:
            resp['data'].append({'url': f'data:image/png;base64,{b64}'})
            image_obj['url'] = f'data:image/png;base64,{b64}'

        resp['data'].append(image_obj)

    return resp


@@ -126,29 +67,3 @@ def _image_to_base64(image) -> str:
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
    """Save images with metadata."""
    date_str = datetime.now().strftime("%Y-%m-%d")
    folder = os.path.join("user_data", "image_outputs", date_str)
    os.makedirs(folder, exist_ok=True)

    metadata = {
        'image_prompt': prompt,
        'image_neg_prompt': negative_prompt,
        'image_width': width,
        'image_height': height,
        'image_steps': steps,
        'image_seed': seed,
        'image_cfg_scale': cfg_scale,
        'model': getattr(shared, 'image_model_name', 'unknown'),
    }

    for idx, img in enumerate(images):
        ts = datetime.now().strftime("%H-%M-%S")
        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")

        png_info = PngInfo()
        png_info.add_text("image_gen_settings", json.dumps(metadata))
        img.save(filepath, pnginfo=png_info)
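
The endpoint returns each image either as a data URL or, when response_format is 'b64_json', as a raw base64 string. A hedged client-side sketch for the b64_json case, using only the response fields visible above (the output path is arbitrary):

import base64

def save_first_image(resp, path="output.png"):
    """Decode the first b64_json entry of a generations() response to a PNG file."""
    b64 = resp['data'][0]['b64_json']
    with open(path, 'wb') as f:
        f.write(base64.b64decode(b64))
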
@@ -7,23 +7,24 @@ import traceback
from collections import deque
from threading import Thread

import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import uvicorn
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool

import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool

from .typing import (
    ChatCompletionRequest,

@@ -232,20 +233,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
    import extensions.openai.images as OAIimages

    async with image_generation_semaphore:
        width, height = request_data.get_width_height()

        response = await asyncio.to_thread(
            OAIimages.generations,
            prompt=request_data.prompt,
            size=f"{width}x{height}",
            response_format=request_data.response_format,
            n=request_data.batch_size,  # <-- use resolved batch_size
            negative_prompt=request_data.negative_prompt,
            steps=request_data.steps,
            seed=request_data.seed,
            cfg_scale=request_data.cfg_scale,
            batch_count=request_data.batch_count,
        )
        response = await asyncio.to_thread(OAIimages.generations, request_data)
        return JSONResponse(response)
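
With the handler reduced to a single asyncio.to_thread call, parameter handling now lives entirely in ImageGenerationRequest (defined below). A hedged usage sketch of the endpoint itself; it assumes the API listens on its default port 5000 and exposes the OpenAI-style path /v1/images/generations, neither of which is confirmed by this diff, and the JSON fields mirror the request model:

import requests

payload = {
    "prompt": "a lighthouse at dusk, oil painting",
    "negative_prompt": "blurry, low quality",
    "size": "1024x1024",
    "steps": 9,
    "image_seed": 42,
    "batch_count": 1,
    "response_format": "b64_json",
}
r = requests.post("http://127.0.0.1:5000/v1/images/generations", json=payload, timeout=600)
r.raise_for_status()
print(len(r.json()["data"]), "image(s) returned")
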
@@ -265,16 +265,13 @@ class LoadLorasRequest(BaseModel):

class ImageGenerationRequest(BaseModel):
    """OpenAI-compatible image generation request with extended parameters."""
    # Required
    """Image-specific parameters for generation."""
    prompt: str

    # Generation parameters
    negative_prompt: str = ""
    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
    steps: int = Field(default=9, ge=1)
    cfg_scale: float = Field(default=0.0, ge=0.0)
    seed: int = Field(default=-1, description="-1 for random")
    image_seed: int = Field(default=-1, description="-1 for random")
    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")

@@ -286,7 +283,6 @@ class ImageGenerationRequest(BaseModel):

    @model_validator(mode='after')
    def resolve_batch_size(self):
        """Use batch_size if provided, otherwise fall back to n."""
        if self.batch_size is None:
            self.batch_size = self.n
        return self
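
A standalone sketch of the resolve_batch_size logic, using a trimmed-down stand-in model rather than the project's actual class, to show that n only acts as a fallback when batch_size is omitted:

from pydantic import BaseModel, Field, model_validator

class ImageRequestSketch(BaseModel):
    n: int = Field(default=1, ge=1)
    batch_size: int | None = Field(default=None, ge=1)

    @model_validator(mode='after')
    def resolve_batch_size(self):
        if self.batch_size is None:
            self.batch_size = self.n
        return self

print(ImageRequestSketch(n=4).batch_size)                # 4: falls back to n
print(ImageRequestSketch(n=4, batch_size=2).batch_size)  # 2: explicit batch_size wins
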
@@ -112,7 +112,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
        add_generation_prompt=False,
        enable_thinking=state['enable_thinking'],
        reasoning_effort=state['reasoning_effort'],
        thinking_budget=-1 if state.get('enable_thinking', True) else 0
        thinking_budget=-1 if state.get('enable_thinking', True) else 0,
        bos_token=shared.bos_token,
        eos_token=shared.eos_token,
    )

    chat_renderer = partial(

@@ -475,7 +477,7 @@ def get_stopping_strings(state):

    if state['mode'] in ['instruct', 'chat-instruct']:
        template = jinja_env.from_string(state['instruction_template_str'])
        renderer = partial(template.render, add_generation_prompt=False)
        renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token)
        renderers.append(renderer)

    if state['mode'] in ['chat']:
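
Passing bos_token and eos_token into template.render matters because some instruction templates reference these variables directly; without them, Jinja2 typically renders them as empty strings. A minimal sketch with an illustrative template string (not one shipped with the project):

from jinja2 import Environment

template_str = (
    "{{ bos_token }}"
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}{{ eos_token }}\n{% endfor %}"
)
template = Environment().from_string(template_str)
print(template.render(
    messages=[{'role': 'user', 'content': 'Hi'}],
    bos_token='<s>',
    eos_token='</s>',
))
# <s>user: Hi</s>
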
@@ -141,16 +141,24 @@ def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offl
    if not cpu_offload:
        pipe.to(get_device())

    # Set attention backend (if supported by the pipeline)
    if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'set_attention_backend'):
        if attn_backend == 'flash_attention_2':
            pipe.transformer.set_attention_backend("flash")
        # sdpa is the default, no action needed
    modules = ["transformer", "unet"]

    # Set attention backend
    if attn_backend == 'flash_attention_2':
        for name in modules:
            mod = getattr(pipe, name, None)
            if hasattr(mod, "set_attention_backend"):
                mod.set_attention_backend("flash")
                break

    # Compile model
    if compile_model:
        if hasattr(pipe, 'transformer') and hasattr(pipe.transformer, 'compile'):
            logger.info("Compiling model (first run will be slow)...")
            pipe.transformer.compile()
        for name in modules:
            mod = getattr(pipe, name, None)
            if hasattr(mod, "compile"):
                logger.info("Compiling model (first run will be slow)...")
                mod.compile()
                break

    if cpu_offload:
        pipe.enable_model_cpu_offload()
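
The probe order above reflects how diffusers pipelines are laid out: DiT-style pipelines such as Z-Image and Qwen-Image expose their denoiser as pipe.transformer, while classic Stable Diffusion pipelines expose pipe.unet. A small hedged sketch of the same probe on an arbitrary pipeline (the checkpoint name is only an example):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/sd-turbo")  # example checkpoint
for name in ["transformer", "unet"]:
    mod = getattr(pipe, name, None)
    if mod is not None:
        print(f"denoiser: {name} ({type(mod).__name__})")  # sd-turbo resolves to unet
        break
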
@@ -89,8 +89,9 @@ def get_model_metadata(model):
        else:
            bos_token = ""

        template = template.replace('eos_token', "'{}'".format(eos_token))
        template = template.replace('bos_token', "'{}'".format(bos_token))

        shared.bos_token = bos_token
        shared.eos_token = eos_token

        template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
        template = re.sub(r'raise_exception\([^)]*\)', "''", template)

@@ -160,13 +161,16 @@ def get_model_metadata(model):

    # 4. If a template was found from any source, process it
    if template:
        shared.bos_token = '<s>'
        shared.eos_token = '</s>'

        for k in ['eos_token', 'bos_token']:
            if k in metadata:
                value = metadata[k]
                if isinstance(value, dict):
                    value = value['content']

                template = template.replace(k, "'{}'".format(value))
                setattr(shared, k, value)

        template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL)
        template = re.sub(r'raise_exception\([^)]*\)', "''", template)
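
The isinstance check exists because bos_token/eos_token metadata can arrive either as a plain string or as an AddedToken-style dict with a content field, as in many tokenizer_config.json files. A minimal standalone sketch of both shapes:

def extract_token(value):
    """Return the literal token string from either metadata shape."""
    if isinstance(value, dict):
        return value['content']
    return value

print(extract_token('</s>'))                                      # </s>
print(extract_token({'content': '<|im_end|>', 'special': True}))  # <|im_end|>
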
@@ -19,6 +19,8 @@ is_seq2seq = False
is_multimodal = False
model_dirty_from_training = False
lora_names = []
bos_token = '<s>'
eos_token = '</s>'

# Image model variables
image_model = None

@@ -319,6 +321,8 @@ settings = {
    'image_seed': -1,
    'image_batch_size': 1,
    'image_batch_count': 1,
    'image_llm_variations': False,
    'image_llm_variations_prompt': 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.',
    'image_model_menu': 'None',
    'image_dtype': 'bfloat16',
    'image_attn_backend': 'sdpa',

@@ -293,6 +293,8 @@ def list_interface_input_elements():
        'image_seed',
        'image_batch_size',
        'image_batch_count',
        'image_llm_variations',
        'image_llm_variations_prompt',
        'image_model_menu',
        'image_dtype',
        'image_attn_backend',

@@ -547,6 +549,8 @@ def setup_auto_save():
        'image_seed',
        'image_batch_size',
        'image_batch_count',
        'image_llm_variations',
        'image_llm_variations_prompt',
        'image_model_menu',
        'image_dtype',
        'image_attn_backend',
@@ -18,7 +18,7 @@ from modules.image_models import (
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.utils import gradio
from modules.utils import check_model_loaded, gradio

ASPECT_RATIOS = {
    "1:1 Square": (1, 1),

@@ -40,8 +40,6 @@ METADATA_SETTINGS_KEYS = [
    'image_aspect_ratio',
    'image_steps',
    'image_seed',
    'image_batch_size',
    'image_batch_count',
    'image_cfg_scale',
]

@@ -136,6 +134,9 @@ def build_generation_metadata(state, actual_seed):

def save_generated_images(images, state, actual_seed):
    """Save images with generation metadata embedded in PNG."""
    if shared.args.multi_user:
        return

    date_str = datetime.now().strftime("%Y-%m-%d")
    folder_path = os.path.join("user_data", "image_outputs", date_str)
    os.makedirs(folder_path, exist_ok=True)

@@ -145,7 +146,7 @@ def save_generated_images(images, state, actual_seed):

    for idx, img in enumerate(images):
        timestamp = datetime.now().strftime("%H-%M-%S")
        filename = f"{timestamp}_{actual_seed:010d}_{idx:03d}.png"
        filename = f"TGW_{timestamp}_{actual_seed:010d}_{idx:03d}.png"
        filepath = os.path.join(folder_path, filename)

        # Create PNG metadata
@@ -159,9 +160,14 @@ def read_image_metadata(image_path):
def read_image_metadata(image_path):
    """Read generation metadata from PNG file."""
    try:
        with open_image_safely(image_path) as img:
    img = open_image_safely(image_path)
    if img is None:
        return None
    try:
        if hasattr(img, 'text') and 'image_gen_settings' in img.text:
            return json.loads(img.text['image_gen_settings'])
    finally:
        img.close()
    except Exception as e:
        logger.debug(f"Could not read metadata from {image_path}: {e}")
    return None
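
A self-contained sketch of the metadata round trip used here: embed a JSON blob under the image_gen_settings key with PngInfo at save time, then read it back through the .text mapping that read_image_metadata() checks:

import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

img = Image.new("RGB", (64, 64))
info = PngInfo()
info.add_text("image_gen_settings", json.dumps({"image_steps": 9, "image_seed": 42}))
img.save("demo.png", pnginfo=info)

with Image.open("demo.png") as loaded:
    print(json.loads(loaded.text["image_gen_settings"]))  # {'image_steps': 9, 'image_seed': 42}
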
@@ -184,8 +190,6 @@ def format_metadata_for_display(metadata):
        ('image_steps', 'Steps'),
        ('image_cfg_scale', 'CFG Scale'),
        ('image_seed', 'Seed'),
        ('image_batch_size', 'Batch Size'),
        ('image_batch_count', 'Batch Count'),
        ('model', 'Model'),
        ('generated_at', 'Generated At'),
    ]

@@ -293,7 +297,6 @@ def on_gallery_select(evt: gr.SelectData, current_page):
    if not _image_cache:
        get_all_history_images()

    # Get the current page's images to find the actual file path
    all_images = _image_cache
    total_images = len(all_images)

@@ -314,11 +317,11 @@ def send_to_generate(selected_image_path):
def send_to_generate(selected_image_path):
    """Load settings from selected image and return updates for all Generate tab inputs."""
    if not selected_image_path or not os.path.exists(selected_image_path):
        return [gr.update()] * 10 + ["No image selected"]
        return [gr.update()] * 8 + ["No image selected"]

    metadata = read_image_metadata(selected_image_path)
    if not metadata:
        return [gr.update()] * 10 + ["No settings found in this image"]
        return [gr.update()] * 8 + ["No settings found in this image"]

    # Return updates for each input element in order
    updates = [

@@ -329,8 +332,6 @@ def send_to_generate(selected_image_path):
        gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')),
        gr.update(value=metadata.get('image_steps', 9)),
        gr.update(value=metadata.get('image_seed', -1)),
        gr.update(value=metadata.get('image_batch_size', 1)),
        gr.update(value=metadata.get('image_batch_count', 1)),
        gr.update(value=metadata.get('image_cfg_scale', 0.0)),
    ]

@@ -370,6 +371,19 @@ def create_ui():
                lines=3,
                value=shared.settings['image_neg_prompt']
            )
            shared.gradio['image_llm_variations'] = gr.Checkbox(
                value=shared.settings['image_llm_variations'],
                label='LLM Prompt Variations',
                elem_id="llm-prompt-variations",
            )
            shared.gradio['image_llm_variations_prompt'] = gr.Textbox(
                value=shared.settings['image_llm_variations_prompt'],
                label='Variation Prompt',
                lines=3,
                placeholder='Instructions for generating prompt variations...',
                visible=shared.settings['image_llm_variations'],
                info='Use the loaded LLM to generate creative prompt variations for each sequential batch.'
            )

            shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg")
            shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False)

@@ -406,6 +420,7 @@ def create_ui():
                    info="Z-Image Turbo: 0.0 | Qwen: 4.0"
                )
                shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random")

                with gr.Column():
                    shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.")
                    shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.")

@@ -647,8 +662,6 @@ def create_event_handlers():
            'image_aspect_ratio',
            'image_steps',
            'image_seed',
            'image_batch_size',
            'image_batch_count',
            'image_cfg_scale',
            'image_gallery_status'
        ),
@@ -663,6 +676,68 @@ def create_event_handlers():
        show_progress=False
    )

    # LLM Variations visibility toggle
    shared.gradio['image_llm_variations'].change(
        lambda x: gr.update(visible=x),
        gradio('image_llm_variations'),
        gradio('image_llm_variations_prompt'),
        show_progress=False
    )


def generate_prompt_variation(state):
    """Generate a creative variation of the image prompt using the LLM."""
    from modules.chat import generate_chat_prompt
    from modules.text_generation import generate_reply

    prompt = state['image_prompt']

    # Check if LLM is loaded
    model_loaded, _ = check_model_loaded()
    if not model_loaded:
        logger.warning("No LLM loaded for prompt variation. Using original prompt.")
        return prompt

    # Get the custom variation prompt or use default
    variation_instruction = state.get('image_llm_variations_prompt', '')
    if not variation_instruction:
        variation_instruction = 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.'

    augmented_message = f"{prompt}\n\n=====\n\n{variation_instruction}"

    # Use minimal state for generation
    var_state = state.copy()
    var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}}
    var_state['auto_max_new_tokens'] = True
    var_state['enable_thinking'] = False
    var_state['reasoning_effort'] = 'low'
    var_state['start_with'] = ""

    formatted_prompt = generate_chat_prompt(augmented_message, var_state)

    variation = ""
    for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True):
        variation = reply

    # Strip thinking blocks if present
    if "</think>" in variation:
        variation = variation.rsplit("</think>", 1)[1]
    elif "<|start|>assistant<|channel|>final<|message|>" in variation:
        variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1]
    elif "</seed:think>" in variation:
        variation = variation.rsplit("</seed:think>", 1)[1]

    variation = variation.strip()
    if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'):
        variation = variation[1:-1]

    if variation:
        logger.info("Prompt variation:")
        print(variation)
        return variation

    return prompt


def progress_bar_html(progress=0, text=""):
    """Generate HTML for progress bar. Empty div when progress <= 0."""
@@ -671,13 +746,13 @@ def progress_bar_html(progress=0, text=""):

    return f'''<div class="image-ai-progress-wrapper">
        <div class="image-ai-progress-track">
            <div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
            <div class="image-ai-progress-fill" style="width: {progress * 100:.1f}%;"></div>
        </div>
        <div class="image-ai-progress-text">{text}</div>
    </div>'''


def generate(state):
def generate(state, save_images=True):
    """
    Generate images using the loaded model.
    Automatically adjusts parameters based on pipeline type.
@@ -720,7 +795,7 @@ def generate(state):
    device = get_device()
    if device is None:
        device = "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))
    generator = torch.Generator(device)

    all_images = []
@@ -729,14 +804,8 @@ def generate(state):
    if pipeline_type is None:
        pipeline_type = get_pipeline_type(shared.image_model)

    # Process Prompt
    prompt = state['image_prompt']

    if pipeline_type == 'qwenimage':
        magic_suffix = ", Ultra HD, 4K, cinematic composition"
        if magic_suffix.strip(", ") not in prompt:
            prompt += magic_suffix

    shared.stop_everything = False

    batch_count = int(state['image_batch_count'])
@@ -777,13 +846,25 @@ def generate(state):

        generator.manual_seed(int(seed + batch_idx))

        # Generate prompt variation if enabled
        if state['image_llm_variations']:
            gen_kwargs["prompt"] = generate_prompt_variation(state)

        # Run generation in thread so we can yield progress
        result_holder = []
        error_holder = []

        def run_batch():
            try:
                # Apply magic suffix only at generation time for qwenimage
                clean_prompt = gen_kwargs["prompt"]
                if pipeline_type == 'qwenimage':
                    magic_suffix = ", Ultra HD, 4K, cinematic composition"
                    if magic_suffix.strip(", ") not in clean_prompt:
                        gen_kwargs["prompt"] = clean_prompt + magic_suffix

                result_holder.extend(shared.image_model(**gen_kwargs).images)
                gen_kwargs["prompt"] = clean_prompt  # restore
            except Exception as e:
                error_holder.append(e)
@@ -806,11 +887,18 @@ def generate(state):
        if error_holder:
            raise error_holder[0]

        # Save this batch's images with the actual prompt and seed used
        if save_images:
            batch_seed = seed + batch_idx
            original_prompt = state['image_prompt']
            state['image_prompt'] = gen_kwargs["prompt"]
            save_generated_images(result_holder, state, batch_seed)
            state['image_prompt'] = original_prompt

        all_images.extend(result_holder)
        yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")

    t1 = time.time()
    save_generated_images(all_images, state, seed)

    total_images = batch_count * int(state['image_batch_size'])
    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')