Add an API endpoint for generating images

oobabooga 2025-12-03 11:50:35 -08:00
parent 9448bf1caa
commit 5433ef3333
3 changed files with 200 additions and 72 deletions
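
For context, a minimal client-side sketch of the new endpoint (not part of the diff; it assumes the OpenAI-compatible API is listening on its default port 5000 and that the requests package is installed):

import base64

import requests

# Endpoint added by this commit; host/port depend on your --api setup
url = "http://127.0.0.1:5000/v1/images/generations"
payload = {
    "prompt": "a watercolor lighthouse at dusk",
    "size": "512x512",
    "steps": 9,
    "seed": 42,
    "response_format": "b64_json",
}
r = requests.post(url, json=payload)
r.raise_for_status()

# Each entry in 'data' carries one PNG image as base64
for i, item in enumerate(r.json()["data"]):
    with open(f"image_{i}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))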

extensions/openai/images.py

@@ -1,70 +1,154 @@
"""
OpenAI-compatible image generation using local diffusion models.
"""
import base64
import io
import json
import os
import time
from datetime import datetime

import numpy as np

from extensions.openai.errors import ServiceUnavailableError
from modules import shared
from modules.logging_colors import logger
from PIL.PngImagePlugin import PngInfo


def generations(prompt: str, size: str, response_format: str, n: int,
                negative_prompt: str = "", steps: int = 9, seed: int = -1,
                cfg_scale: float = 0.0, batch_count: int = 1):
"""
Generate images using the loaded diffusion model.
Args:
prompt: Text description of the desired image
size: Image dimensions as "WIDTHxHEIGHT"
response_format: 'url' or 'b64_json'
n: Number of images per batch
negative_prompt: What to avoid in the image
steps: Number of inference steps
seed: Random seed (-1 for random)
cfg_scale: Classifier-free guidance scale
batch_count: Number of sequential batches
Returns:
dict with 'created' timestamp and 'data' list of images
"""
    import torch

    from modules.image_models import get_pipeline_type
    from modules.torch_utils import clear_torch_cache, get_device

    if shared.image_model is None:
        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")

    clear_torch_cache()

    # Parse dimensions
    try:
        width, height = [int(x) for x in size.split('x')]
    except (ValueError, IndexError):
        width, height = 1024, 1024

    # Handle seed
    if seed == -1:
        seed = np.random.randint(0, 2**32 - 1)

    device = get_device() or "cpu"
    generator = torch.Generator(device).manual_seed(int(seed))

    # Get pipeline type for CFG parameter name
    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)

    # Build generation kwargs
    gen_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "num_images_per_prompt": n,
        "generator": generator,
    }

    # Pipeline-specific CFG parameter
    if pipeline_type == 'qwenimage':
        gen_kwargs["true_cfg_scale"] = cfg_scale
    else:
        gen_kwargs["guidance_scale"] = cfg_scale

    # Generate
    all_images = []
    t0 = time.time()

    shared.stop_everything = False

    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
        if shared.stop_everything:
            pipe._interrupt = True

        return callback_kwargs

    gen_kwargs["callback_on_step_end"] = interrupt_callback

    for i in range(batch_count):
        if shared.stop_everything:
            break

        generator.manual_seed(int(seed + i))
        batch_results = shared.image_model(**gen_kwargs).images
        all_images.extend(batch_results)

    t1 = time.time()
    total_images = len(all_images)
    total_steps = steps * batch_count
    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')

    # Save images
    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)

    # Build response
    resp = {
        'created': int(time.time()),
        'data': []
    }
    for img in all_images:
        b64 = _image_to_base64(img)
        if response_format == 'b64_json':
            resp['data'].append({'b64_json': b64})
        else:
            resp['data'].append({'url': f'data:image/png;base64,{b64}'})

    return resp


def _image_to_base64(image) -> str:
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
    """Save images with metadata."""
    date_str = datetime.now().strftime("%Y-%m-%d")
    folder = os.path.join("user_data", "image_outputs", date_str)
    os.makedirs(folder, exist_ok=True)

    metadata = {
        'image_prompt': prompt,
        'image_neg_prompt': negative_prompt,
        'image_width': width,
        'image_height': height,
        'image_steps': steps,
        'image_seed': seed,
        'image_cfg_scale': cfg_scale,
        'model': getattr(shared, 'image_model_name', 'unknown'),
    }

    for idx, img in enumerate(images):
        ts = datetime.now().strftime("%H-%M-%S")
        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
        png_info = PngInfo()
        png_info.add_text("image_gen_settings", json.dumps(metadata))
        img.save(filepath, pnginfo=png_info)
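
Because _save_images embeds the request parameters as a PNG text chunk, they can be read back later with Pillow; a short sketch (the file path below is illustrative):

import json

from PIL import Image

# Outputs land under user_data/image_outputs/<date>/; this path is made up
img = Image.open("user_data/image_outputs/2025-12-03/12-00-00_0000000042_000.png")
settings = json.loads(img.text["image_gen_settings"])
print(settings["image_prompt"], settings["image_seed"])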

extensions/openai/script.py

@@ -7,26 +7,23 @@ import traceback
from collections import deque
from threading import Thread

import extensions.openai.completions as OAIcompletions
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import uvicorn
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
from .typing import (
    ChatCompletionRequest,
@@ -40,6 +37,8 @@ from .typing import (
    EmbeddingsResponse,
    EncodeRequest,
    EncodeResponse,
    ImageGenerationRequest,
    ImageGenerationResponse,
    LoadLorasRequest,
    LoadModelRequest,
    LogitsRequest,
@@ -228,19 +227,24 @@ async def handle_audio_transcription(request: Request):
    return JSONResponse(content=transcription)


@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request_data: ImageGenerationRequest):
    import extensions.openai.images as OAIimages

    width, height = request_data.get_width_height()

    response = await asyncio.to_thread(
        OAIimages.generations,
        prompt=request_data.prompt,
        size=f"{width}x{height}",
        response_format=request_data.response_format,
        n=request_data.batch_size,  # use the resolved batch_size
        negative_prompt=request_data.negative_prompt,
        steps=request_data.steps,
        seed=request_data.seed,
        cfg_scale=request_data.cfg_scale,
        batch_count=request_data.batch_count,
    )
    return JSONResponse(response)
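
The handler wraps the synchronous generations() call in asyncio.to_thread so a long diffusion run does not block FastAPI's event loop; a self-contained sketch of that pattern (with a sleep standing in for the model call):

import asyncio
import time

def blocking_generation():
    # stands in for the synchronous diffusion call
    time.sleep(2)
    return "images"

async def main():
    # the event loop stays free to serve other requests
    # while the generation runs in a worker thread
    result = await asyncio.to_thread(blocking_generation)
    print(result)

asyncio.run(main())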

extensions/openai/typing.py

@@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel):
    lora_names: List[str]


class ImageGenerationRequest(BaseModel):
    """OpenAI-compatible image generation request with extended parameters."""
    # Required
    prompt: str

    # Generation parameters
    negative_prompt: str = ""
    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
    steps: int = Field(default=9, ge=1)
    cfg_scale: float = Field(default=0.0, ge=0.0)
    seed: int = Field(default=-1, description="-1 for random")
    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")

    # OpenAI compatibility (model and user are accepted but ignored)
    model: str | None = None
    response_format: str = "b64_json"
    user: str | None = None

    @model_validator(mode='after')
    def resolve_batch_size(self):
        """Use batch_size if provided, otherwise fall back to n."""
        if self.batch_size is None:
            self.batch_size = self.n

        return self

    def get_width_height(self) -> tuple[int, int]:
        try:
            parts = self.size.lower().split('x')
            return int(parts[0]), int(parts[1])
        except (ValueError, IndexError):
            return 1024, 1024


class ImageGenerationResponse(BaseModel):
    # default_factory so the timestamp is taken per response, not at import time
    created: int = Field(default_factory=lambda: int(time.time()))
    data: List[dict]


def to_json(obj):
    return json.dumps(obj.__dict__, indent=4)
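
A quick sketch of how the validator and the size parser behave (assuming pydantic v2, which provides model_validator, and the ImageGenerationRequest class above):

req = ImageGenerationRequest(prompt="a cat", n=4, size="640x480")
assert req.batch_size == 4                       # falls back to n
assert req.get_width_height() == (640, 480)

req = ImageGenerationRequest(prompt="a cat", n=4, batch_size=2, size="bogus")
assert req.batch_size == 2                       # explicit batch_size wins
assert req.get_width_height() == (1024, 1024)    # malformed size falls back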