Image: Simplify the API code, add the llm_variations option
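This replaces the hand-rolled diffusion loop in the API with a thin wrapper around the UI's generate() pipeline. A minimal client call exercising the new option (a sketch: the host/port are assumed defaults, and the exact behavior of llm_variations is defined by the UI pipeline, not shown in this diff):

import requests

# Hypothetical example request; assumes the API server is running locally.
payload = {
    "prompt": "a watercolor lighthouse at dusk",
    "size": "1024x1024",
    "steps": 9,
    "llm_variations": True,  # new option added by this commit
    "response_format": "b64_json",
}
r = requests.post("http://127.0.0.1:5000/v1/images/generations", json=payload)
print(len(r.json()["data"]), "image(s) returned")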

oobabooga 2025-12-04 10:23:00 -08:00
parent 2793153717
commit 5763947c37
4 changed files with 49 additions and 153 deletions

View file

@@ -4,117 +4,50 @@ OpenAI-compatible image generation using local diffusion models.
import base64
import io
import json
import os
import time
from datetime import datetime
import numpy as np
from extensions.openai.errors import ServiceUnavailableError
from modules import shared
from modules.logging_colors import logger
from PIL.PngImagePlugin import PngInfo
def generations(prompt: str, size: str, response_format: str, n: int,
negative_prompt: str = "", steps: int = 9, seed: int = -1,
cfg_scale: float = 0.0, batch_count: int = 1):
def generations(request):
"""
Generate images using the loaded diffusion model.
Args:
prompt: Text description of the desired image
size: Image dimensions as "WIDTHxHEIGHT"
response_format: 'url' or 'b64_json'
n: Number of images per batch
negative_prompt: What to avoid in the image
steps: Number of inference steps
seed: Random seed (-1 for random)
cfg_scale: Classifier-free guidance scale
batch_count: Number of sequential batches
Returns:
dict with 'created' timestamp and 'data' list of images
Returns dict with 'created' timestamp and 'data' list of images.
"""
import torch
from modules.image_models import get_pipeline_type
from modules.torch_utils import clear_torch_cache, get_device
from modules.ui_image_generation import generate
if shared.image_model is None:
raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
clear_torch_cache()
width, height = request.get_width_height()
# Parse dimensions
try:
width, height = [int(x) for x in size.split('x')]
except (ValueError, IndexError):
width, height = 1024, 1024
# Build state dict: GenerationOptions fields + image-specific keys
state = request.model_dump()
state.update({
'image_model_menu': shared.image_model_name,
'image_prompt': request.prompt,
'image_neg_prompt': request.negative_prompt,
'image_width': width,
'image_height': height,
'image_steps': request.steps,
'image_seed': request.image_seed,
'image_batch_size': request.batch_size,
'image_batch_count': request.batch_count,
'image_cfg_scale': request.cfg_scale,
'image_llm_variations': request.llm_variations,
})
# Handle seed
if seed == -1:
seed = np.random.randint(0, 2**32 - 1)
device = get_device() or "cpu"
generator = torch.Generator(device).manual_seed(int(seed))
# Get pipeline type for CFG parameter name
pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
# Build generation kwargs
gen_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_inference_steps": steps,
"num_images_per_prompt": n,
"generator": generator,
}
# Pipeline-specific CFG parameter
if pipeline_type == 'qwenimage':
gen_kwargs["true_cfg_scale"] = cfg_scale
else:
gen_kwargs["guidance_scale"] = cfg_scale
# Generate
all_images = []
t0 = time.time()
shared.stop_everything = False
def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
if shared.stop_everything:
pipe._interrupt = True
return callback_kwargs
gen_kwargs["callback_on_step_end"] = interrupt_callback
for i in range(batch_count):
if shared.stop_everything:
break
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
t1 = time.time()
total_images = len(all_images)
total_steps = steps * batch_count
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
# Save images
_save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
# Exhaust generator, keep final result
images = []
for images, _ in generate(state, save_images=False):
pass
# Build response
resp = {
'created': int(time.time()),
'data': []
}
for img in all_images:
resp = {'created': int(time.time()), 'data': []}
for img in images:
b64 = _image_to_base64(img)
if response_format == 'b64_json':
if request.response_format == 'b64_json':
resp['data'].append({'b64_json': b64})
else:
resp['data'].append({'url': f'data:image/png;base64,{b64}'})
@@ -126,29 +59,3 @@ def _image_to_base64(image) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
"""Save images with metadata."""
date_str = datetime.now().strftime("%Y-%m-%d")
folder = os.path.join("user_data", "image_outputs", date_str)
os.makedirs(folder, exist_ok=True)
metadata = {
'image_prompt': prompt,
'image_neg_prompt': negative_prompt,
'image_width': width,
'image_height': height,
'image_steps': steps,
'image_seed': seed,
'image_cfg_scale': cfg_scale,
'model': getattr(shared, 'image_model_name', 'unknown'),
}
for idx, img in enumerate(images):
ts = datetime.now().strftime("%H-%M-%S")
filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
png_info = PngInfo()
png_info.add_text("image_gen_settings", json.dumps(metadata))
img.save(filepath, pnginfo=png_info)
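For reference, generate() is a generator that yields (images, progress_html) after each completed batch; the API only needs the finished list, so it drains the generator and keeps the last yield (a sketch of the idiom, names as in the diff above):

images = []
for images, _progress_html in generate(state, save_images=False):
    pass  # after the loop, `images` holds the final, complete list of PIL images

The deleted _save_images() helper duplicated what generate() already does through save_generated_images(); the API now opts out of disk writes with save_images=False instead of reimplementing saving.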

View file

@@ -7,23 +7,24 @@ import traceback
from collections import deque
from threading import Thread
import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import uvicorn
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
import extensions.openai.completions as OAIcompletions
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
from .typing import (
ChatCompletionRequest,
@@ -232,20 +233,7 @@ async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages
async with image_generation_semaphore:
width, height = request_data.get_width_height()
response = await asyncio.to_thread(
OAIimages.generations,
prompt=request_data.prompt,
size=f"{width}x{height}",
response_format=request_data.response_format,
n=request_data.batch_size, # <-- use resolved batch_size
negative_prompt=request_data.negative_prompt,
steps=request_data.steps,
seed=request_data.seed,
cfg_scale=request_data.cfg_scale,
batch_count=request_data.batch_count,
)
response = await asyncio.to_thread(OAIimages.generations, request_data)
return JSONResponse(response)
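The handler now forwards the validated request object as-is. The surrounding pattern, in isolation (a runnable sketch; the real semaphore limit is not shown in this diff, and blocking_generate stands in for OAIimages.generations):

import asyncio
import time

image_generation_semaphore = asyncio.Semaphore(1)  # assumed limit

def blocking_generate(request_data):
    time.sleep(1)  # placeholder for the blocking diffusion pipeline
    return {"created": int(time.time()), "data": []}

async def handle_image_generation(request_data):
    # The semaphore serializes image requests; to_thread keeps the event
    # loop responsive while the pipeline runs in a worker thread.
    async with image_generation_semaphore:
        return await asyncio.to_thread(blocking_generate, request_data)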

View file

@@ -264,20 +264,18 @@ class LoadLorasRequest(BaseModel):
lora_names: List[str]
class ImageGenerationRequest(BaseModel):
"""OpenAI-compatible image generation request with extended parameters."""
# Required
class ImageGenerationRequestParams(BaseModel):
"""Image-specific parameters for generation."""
prompt: str
# Generation parameters
negative_prompt: str = ""
size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
steps: int = Field(default=9, ge=1)
cfg_scale: float = Field(default=0.0, ge=0.0)
seed: int = Field(default=-1, description="-1 for random")
image_seed: int = Field(default=-1, description="-1 for random")
batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
llm_variations: bool = False
# OpenAI compatibility (unused)
model: str | None = None
@@ -286,7 +284,6 @@ class ImageGenerationRequest(BaseModel):
@model_validator(mode='after')
def resolve_batch_size(self):
"""Use batch_size if provided, otherwise fall back to n."""
if self.batch_size is None:
self.batch_size = self.n
return self
@@ -299,6 +296,10 @@ class ImageGenerationRequest(BaseModel):
return 1024, 1024
class ImageGenerationRequest(GenerationOptions, ImageGenerationRequestParams):
pass
class ImageGenerationResponse(BaseModel):
created: int = int(time.time())
data: List[dict]
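Because ImageGenerationRequest now inherits GenerationOptions as well, request.model_dump() yields every sampler field alongside the image fields, which is what lets the API seed the UI state dict directly. The batch_size/n fallback from the validator above, in use (a sketch; assumes the GenerationOptions fields all carry defaults):

from extensions.openai.typing import ImageGenerationRequest

req = ImageGenerationRequest(prompt="test", n=4)  # batch_size omitted
assert req.batch_size == 4                        # validator falls back to n

req = ImageGenerationRequest(prompt="test", n=4, batch_size=2)
assert req.batch_size == 2                        # explicit batch_size wins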

View file

@@ -10,7 +10,6 @@ import numpy as np
from PIL.PngImagePlugin import PngInfo
from modules import shared, ui, utils
from modules.utils import check_model_loaded
from modules.image_models import (
get_pipeline_type,
load_image_model,
@@ -19,7 +18,7 @@ from modules.image_models import (
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.utils import gradio
from modules.utils import check_model_loaded, gradio
ASPECT_RATIOS = {
"1:1 Square": (1, 1),
@@ -725,13 +724,13 @@ def progress_bar_html(progress=0, text=""):
return f'''<div class="image-ai-progress-wrapper">
<div class="image-ai-progress-track">
<div class="image-ai-progress-fill" style="width: {progress*100:.1f}%;"></div>
<div class="image-ai-progress-fill" style="width: {progress * 100:.1f}%;"></div>
</div>
<div class="image-ai-progress-text">{text}</div>
</div>'''
def generate(state):
def generate(state, save_images=True):
"""
Generate images using the loaded model.
Automatically adjusts parameters based on pipeline type.
@@ -868,7 +867,8 @@ def generate(state):
yield all_images, progress_bar_html((batch_idx + 1) / batch_count, f"Batch {batch_idx + 1}/{batch_count} complete")
t1 = time.time()
save_generated_images(all_images, state, seed)
if save_images:
save_generated_images(all_images, state, seed)
total_images = batch_count * int(state['image_batch_size'])
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
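The save_images flag keeps the two callers decoupled: the Gradio UI keeps the default and writes PNGs (with generation settings embedded via PngInfo), while the API passes save_images=False and base64-encodes the results itself. Both call sites, sketched (state construction elided):

# UI path (default): images are written to disk via save_generated_images()
for images, progress_html in generate(state):
    ...  # progress_html drives the progress bar between batches

# API path: no disk writes; the caller encodes the returned PIL images
for images, _ in generate(state, save_images=False):
    pass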