Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-01-20 15:40:23 +01:00
Image: Simplify the API code, add the llm_variations option
parent 2793153717
commit 5763947c37
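The commit collapses the standalone diffusion loop in the OpenAI extension into a thin wrapper around the UI's generate() pipeline and adds a new llm_variations flag to the request model. A minimal client sketch, assuming the handler below is mounted at the usual /v1/images/generations route and the API listens on the default port 5000 (neither appears in this diff):

import requests

payload = {
    "prompt": "a lighthouse at dusk, oil painting",
    "size": "1024x1024",
    "steps": 9,
    "image_seed": -1,        # -1 selects a random seed
    "n": 2,                  # OpenAI-style alias; resolved into batch_size
    "batch_count": 1,        # sequential batches
    "llm_variations": True,  # new option introduced by this commit
    "response_format": "b64_json",
}

resp = requests.post("http://127.0.0.1:5000/v1/images/generations", json=payload)
for i, item in enumerate(resp.json()["data"]):
    print(f"image {i}: {len(item['b64_json'])} base64 chars")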
@@ -4,117 +4,50 @@ OpenAI-compatible image generation using local diffusion models.
 import base64
 import io
-import json
-import os
 import time
-from datetime import datetime
-
-import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
 from modules import shared
-from modules.logging_colors import logger
-from PIL.PngImagePlugin import PngInfo
 
 
-def generations(prompt: str, size: str, response_format: str, n: int,
-                negative_prompt: str = "", steps: int = 9, seed: int = -1,
-                cfg_scale: float = 0.0, batch_count: int = 1):
+def generations(request):
     """
     Generate images using the loaded diffusion model.
-
-    Args:
-        prompt: Text description of the desired image
-        size: Image dimensions as "WIDTHxHEIGHT"
-        response_format: 'url' or 'b64_json'
-        n: Number of images per batch
-        negative_prompt: What to avoid in the image
-        steps: Number of inference steps
-        seed: Random seed (-1 for random)
-        cfg_scale: Classifier-free guidance scale
-        batch_count: Number of sequential batches
-
-    Returns:
-        dict with 'created' timestamp and 'data' list of images
+    Returns dict with 'created' timestamp and 'data' list of images.
     """
-    import torch
-    from modules.image_models import get_pipeline_type
-    from modules.torch_utils import clear_torch_cache, get_device
+    from modules.ui_image_generation import generate
 
     if shared.image_model is None:
         raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
 
-    clear_torch_cache()
+    width, height = request.get_width_height()
 
-    # Parse dimensions
-    try:
-        width, height = [int(x) for x in size.split('x')]
-    except (ValueError, IndexError):
-        width, height = 1024, 1024
+    # Build state dict: GenerationOptions fields + image-specific keys
+    state = request.model_dump()
+    state.update({
+        'image_model_menu': shared.image_model_name,
+        'image_prompt': request.prompt,
+        'image_neg_prompt': request.negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': request.steps,
+        'image_seed': request.image_seed,
+        'image_batch_size': request.batch_size,
+        'image_batch_count': request.batch_count,
+        'image_cfg_scale': request.cfg_scale,
+        'image_llm_variations': request.llm_variations,
+    })
 
-    # Handle seed
-    if seed == -1:
-        seed = np.random.randint(0, 2**32 - 1)
-
-    device = get_device() or "cpu"
-    generator = torch.Generator(device).manual_seed(int(seed))
-
-    # Get pipeline type for CFG parameter name
-    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
-
-    # Build generation kwargs
-    gen_kwargs = {
-        "prompt": prompt,
-        "negative_prompt": negative_prompt,
-        "height": height,
-        "width": width,
-        "num_inference_steps": steps,
-        "num_images_per_prompt": n,
-        "generator": generator,
-    }
-
-    # Pipeline-specific CFG parameter
-    if pipeline_type == 'qwenimage':
-        gen_kwargs["true_cfg_scale"] = cfg_scale
-    else:
-        gen_kwargs["guidance_scale"] = cfg_scale
-
-    # Generate
-    all_images = []
-    t0 = time.time()
-
-    shared.stop_everything = False
-
-    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
-        if shared.stop_everything:
-            pipe._interrupt = True
-        return callback_kwargs
-
-    gen_kwargs["callback_on_step_end"] = interrupt_callback
-
-    for i in range(batch_count):
-        if shared.stop_everything:
-            break
-        generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(**gen_kwargs).images
-        all_images.extend(batch_results)
-
-    t1 = time.time()
-    total_images = len(all_images)
-    total_steps = steps * batch_count
-    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
-
-    # Save images
-    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
+    # Exhaust generator, keep final result
+    images = []
+    for images, _ in generate(state, save_images=False):
+        pass
 
-    # Build response
-    resp = {
-        'created': int(time.time()),
-        'data': []
-    }
-
-    for img in all_images:
+    resp = {'created': int(time.time()), 'data': []}
+    for img in images:
         b64 = _image_to_base64(img)
-        if response_format == 'b64_json':
+        if request.response_format == 'b64_json':
            resp['data'].append({'b64_json': b64})
         else:
            resp['data'].append({'url': f'data:image/png;base64,{b64}'})
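The rewritten generations() no longer drives the diffusers pipeline itself: it packs the validated request into a state dict and drains the UI generator, keeping only the final yield. A toy illustration of that draining idiom, with a hypothetical stand-in generator rather than the real generate():

def fake_generate(state, save_images=False):
    # stands in for modules.ui_image_generation.generate, which yields
    # (images_so_far, progress_html) after every batch
    yield ["img_0"], "progress: 50%"
    yield ["img_0", "img_1"], "progress: 100%"

images = []  # initializer guards against a generator that yields nothing
for images, _ in fake_generate({}, save_images=False):
    pass  # the loop variable keeps the value from the last yield

assert images == ["img_0", "img_1"]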
@@ -126,29 +59,3 @@ def _image_to_base64(image) -> str:
     buffered = io.BytesIO()
     image.save(buffered, format="PNG")
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-
-def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
-    """Save images with metadata."""
-    date_str = datetime.now().strftime("%Y-%m-%d")
-    folder = os.path.join("user_data", "image_outputs", date_str)
-    os.makedirs(folder, exist_ok=True)
-
-    metadata = {
-        'image_prompt': prompt,
-        'image_neg_prompt': negative_prompt,
-        'image_width': width,
-        'image_height': height,
-        'image_steps': steps,
-        'image_seed': seed,
-        'image_cfg_scale': cfg_scale,
-        'model': getattr(shared, 'image_model_name', 'unknown'),
-    }
-
-    for idx, img in enumerate(images):
-        ts = datetime.now().strftime("%H-%M-%S")
-        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
-
-        png_info = PngInfo()
-        png_info.add_text("image_gen_settings", json.dumps(metadata))
-        img.save(filepath, pnginfo=png_info)
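_save_images disappears because the API now defers to the UI path, whose generate() already writes outputs via save_generated_images (see the change further down). Assuming the UI saver embeds the same 'image_gen_settings' PNG text chunk that the deleted helper wrote, settings remain recoverable from a saved file; a sketch with a hypothetical path:

import json
from PIL import Image

# hypothetical output path; files are named <time>_<seed>_<index>.png
img = Image.open("user_data/image_outputs/2026-01-20/15-40-23_0000000042_000.png")
settings = json.loads(img.info["image_gen_settings"])
print(settings["image_prompt"], settings["image_seed"])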
@@ -7,23 +7,24 @@ import traceback
 from collections import deque
 from threading import Thread
 
-import extensions.openai.completions as OAIcompletions
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
 import uvicorn
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
+
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool
 
 from .typing import (
     ChatCompletionRequest,
@@ -232,20 +233,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
     import extensions.openai.images as OAIimages
 
     async with image_generation_semaphore:
-        width, height = request_data.get_width_height()
-
-        response = await asyncio.to_thread(
-            OAIimages.generations,
-            prompt=request_data.prompt,
-            size=f"{width}x{height}",
-            response_format=request_data.response_format,
-            n=request_data.batch_size,  # <-- use resolved batch_size
-            negative_prompt=request_data.negative_prompt,
-            steps=request_data.steps,
-            seed=request_data.seed,
-            cfg_scale=request_data.cfg_scale,
-            batch_count=request_data.batch_count,
-        )
+        response = await asyncio.to_thread(OAIimages.generations, request_data)
         return JSONResponse(response)
 
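With the request model carrying every parameter, the handler shrinks to a single asyncio.to_thread call; the semaphore still serializes jobs so concurrent requests queue instead of competing for VRAM. The pattern in isolation, as a sketch with a dummy blocking function rather than the extension's actual wiring:

import asyncio
import time

semaphore = asyncio.Semaphore(1)  # one image job at a time

def blocking_generation(request_data):
    time.sleep(1)  # stands in for the diffusion pipeline
    return {"created": int(time.time()), "data": []}

async def handle(request_data):
    async with semaphore:
        # runs the blocking call in a worker thread, keeping the event loop free
        return await asyncio.to_thread(blocking_generation, request_data)

print(asyncio.run(handle({"prompt": "test"})))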
@@ -264,20 +264,18 @@ class LoadLorasRequest(BaseModel):
     lora_names: List[str]
 
 
-class ImageGenerationRequest(BaseModel):
-    """OpenAI-compatible image generation request with extended parameters."""
-    # Required
+class ImageGenerationRequestParams(BaseModel):
+    """Image-specific parameters for generation."""
     prompt: str
-
-    # Generation parameters
     negative_prompt: str = ""
     size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
     steps: int = Field(default=9, ge=1)
     cfg_scale: float = Field(default=0.0, ge=0.0)
-    seed: int = Field(default=-1, description="-1 for random")
+    image_seed: int = Field(default=-1, description="-1 for random")
     batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
     n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
     batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
+    llm_variations: bool = False
 
     # OpenAI compatibility (unused)
     model: str | None = None
@@ -286,7 +284,6 @@ class ImageGenerationRequest(BaseModel):
-
     @model_validator(mode='after')
     def resolve_batch_size(self):
        """Use batch_size if provided, otherwise fall back to n."""
        if self.batch_size is None:
            self.batch_size = self.n
        return self
@@ -299,6 +296,10 @@ class ImageGenerationRequest(BaseModel):
         return 1024, 1024
 
 
+class ImageGenerationRequest(GenerationOptions, ImageGenerationRequestParams):
+    pass
+
+
 class ImageGenerationResponse(BaseModel):
     created: int = int(time.time())
     data: List[dict]
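The request model is now composed by multiple inheritance: ImageGenerationRequest merges the sampler fields of GenerationOptions (defined elsewhere in this file) with the image-specific fields above, which is what lets generations() turn request.model_dump() into a complete state dict in one step. Expected behavior, sketched:

req = ImageGenerationRequest(prompt="a red fox in the snow", n=2)

assert req.batch_size == 2                     # resolve_batch_size fell back to n
assert req.get_width_height() == (1024, 1024)  # default "1024x1024" size
assert req.llm_variations is False

state = req.model_dump()  # GenerationOptions fields + image params together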
@@ -10,7 +10,6 @@ import numpy as np
 from PIL.PngImagePlugin import PngInfo
 
 from modules import shared, ui, utils
-from modules.utils import check_model_loaded
 from modules.image_models import (
     get_pipeline_type,
     load_image_model,
@@ -19,7 +18,7 @@ from modules.image_models import (
 from modules.image_utils import open_image_safely
 from modules.logging_colors import logger
 from modules.text_generation import stop_everything_event
-from modules.utils import gradio
+from modules.utils import check_model_loaded, gradio
 
 ASPECT_RATIOS = {
     "1:1 Square": (1, 1),
@@ -725,13 +724,13 @@ def progress_bar_html(progress=0, text=""):
 
     return f'''<div class="image-ai-progress-wrapper">
     <div class="image-ai-progress-track">
-        <div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
+        <div class="image-ai-progress-fill" style="width: {progress * 100:.1f}%;"></div>
     </div>
     <div class="image-ai-progress-text">{text}</div>
 </div>'''
 
 
-def generate(state):
+def generate(state, save_images=True):
     """
     Generate images using the loaded model.
     Automatically adjusts parameters based on pipeline type.
@@ -868,7 +867,8 @@ def generate(state):
         yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
 
     t1 = time.time()
-    save_generated_images(all_images, state, seed)
+    if save_images:
+        save_generated_images(all_images, state, seed)
 
     total_images = batch_count * int(state['image_batch_size'])
     logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
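The new save_images flag is what makes the reuse possible: the UI keeps the default and writes PNGs to disk, while the API path opts out and only returns the in-memory images. Usage, sketched:

# UI path: saves each batch to user_data/image_outputs/<date>/
for images, progress_html in generate(state):
    pass

# API path: same pipeline, nothing written to disk
for images, progress_html in generate(state, save_images=False):
    pass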