Add an API endpoint for generating images

oobabooga 2025-12-03 11:50:35 -08:00
parent 9448bf1caa
commit 5433ef3333
3 changed files with 200 additions and 72 deletions

extensions/openai/images.py

@@ -1,70 +1,154 @@
+"""
+OpenAI-compatible image generation using local diffusion models.
+"""
+import base64
+import io
+import json
 import os
 import time
+from datetime import datetime

-import requests
+import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
+from modules import shared
+from modules.logging_colors import logger
+from PIL.PngImagePlugin import PngInfo


-def generations(prompt: str, size: str, response_format: str, n: int):
-    # Stable Diffusion callout wrapper for txt2img
-    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
-    # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
-    # If you want high quality tailored results you should just use the Stable Diffusion API directly.
-    # it's too general an API to try and shape the result with specific tags like negative prompts
-    # or "masterpiece", etc. SD configuration is beyond the scope of this API.
-    # At this point I will not add the edits and variations endpoints (ie. img2img) because they
-    # require changing the form data handling to accept multipart form data, also to properly support
-    # url return types will require file management and a web serving files... Perhaps later!
-    base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
-    sd_defaults = {
-        'sampler_name': 'DPM++ 2M Karras',  # vast improvement
-        'steps': 30,
-    }
-
-    width, height = [int(x) for x in size.split('x')]  # ignore the restrictions on size
-
-    # to hack on better generation, edit default payload.
-    payload = {
-        'prompt': prompt,  # ignore prompt limit of 1000 characters
-        'width': width,
-        'height': height,
-        'batch_size': n,
-    }
-    payload.update(sd_defaults)
-
-    scale = min(width, height) / base_model_size
-    if scale >= 1.2:
-        # for better performance with the default size (1024), and larger res.
-        scaler = {
-            'width': width // scale,
-            'height': height // scale,
-            'hr_scale': scale,
-            'enable_hr': True,
-            'hr_upscaler': 'Latent',
-            'denoising_strength': 0.68,
-        }
-        payload.update(scaler)
+def generations(prompt: str, size: str, response_format: str, n: int,
+                negative_prompt: str = "", steps: int = 9, seed: int = -1,
+                cfg_scale: float = 0.0, batch_count: int = 1):
+    """
+    Generate images using the loaded diffusion model.
+
+    Args:
+        prompt: Text description of the desired image
+        size: Image dimensions as "WIDTHxHEIGHT"
+        response_format: 'url' or 'b64_json'
+        n: Number of images per batch
+        negative_prompt: What to avoid in the image
+        steps: Number of inference steps
+        seed: Random seed (-1 for random)
+        cfg_scale: Classifier-free guidance scale
+        batch_count: Number of sequential batches
+
+    Returns:
+        dict with 'created' timestamp and 'data' list of images
+    """
+    import torch
+
+    from modules.image_models import get_pipeline_type
+    from modules.torch_utils import clear_torch_cache, get_device
+
+    if shared.image_model is None:
+        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
+
+    clear_torch_cache()
+
+    # Parse dimensions
+    try:
+        width, height = [int(x) for x in size.split('x')]
+    except (ValueError, IndexError):
+        width, height = 1024, 1024
+
+    # Handle seed
+    if seed == -1:
+        seed = np.random.randint(0, 2**32 - 1)
+
+    device = get_device() or "cpu"
+    generator = torch.Generator(device).manual_seed(int(seed))
+
+    # Get pipeline type for CFG parameter name
+    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs
+    gen_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "height": height,
+        "width": width,
+        "num_inference_steps": steps,
+        "num_images_per_prompt": n,
+        "generator": generator,
+    }
+
+    # Pipeline-specific CFG parameter
+    if pipeline_type == 'qwenimage':
+        gen_kwargs["true_cfg_scale"] = cfg_scale
+    else:
+        gen_kwargs["guidance_scale"] = cfg_scale
+
+    # Generate
+    all_images = []
+    t0 = time.time()
+    shared.stop_everything = False
+
+    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+        if shared.stop_everything:
+            pipe._interrupt = True
+        return callback_kwargs
+
+    gen_kwargs["callback_on_step_end"] = interrupt_callback
+
+    for i in range(batch_count):
+        if shared.stop_everything:
+            break
+        generator.manual_seed(int(seed + i))
+        batch_results = shared.image_model(**gen_kwargs).images
+        all_images.extend(batch_results)
+
+    t1 = time.time()
+    total_images = len(all_images)
+    total_steps = steps * batch_count
+    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
+
+    # Save images
+    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
+
+    # Build response
     resp = {
         'created': int(time.time()),
         'data': []
     }

-    from extensions.openai.script import params
-    # TODO: support SD_WEBUI_AUTH username:password pair.
-    sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
-
-    response = requests.post(url=sd_url, json=payload)
-    r = response.json()
-    if response.status_code != 200 or 'images' not in r:
-        print(r)
-        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
-    # r['parameters']...
-
-    for b64_json in r['images']:
+    for img in all_images:
+        b64 = _image_to_base64(img)
         if response_format == 'b64_json':
-            resp['data'].extend([{'b64_json': b64_json}])
+            resp['data'].append({'b64_json': b64})
         else:
-            resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # yeah it's lazy. requests.get() will not work with this
+            resp['data'].append({'url': f'data:image/png;base64,{b64}'})

     return resp
+
+
+def _image_to_base64(image) -> str:
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
+    """Save images with metadata."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder, exist_ok=True)
+
+    metadata = {
+        'image_prompt': prompt,
+        'image_neg_prompt': negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': steps,
+        'image_seed': seed,
+        'image_cfg_scale': cfg_scale,
+        'model': getattr(shared, 'image_model_name', 'unknown'),
+    }
+
+    for idx, img in enumerate(images):
+        ts = datetime.now().strftime("%H-%M-%S")
+        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", json.dumps(metadata))
+        img.save(filepath, pnginfo=png_info)
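
For reference, a minimal client-side sketch of the new endpoint. The host and port are assumptions (the OpenAI extension listens on 127.0.0.1:5000 by default); the fields beyond the OpenAI spec ride in the same JSON body:

# Client sketch for POST /v1/images/generations.
# The URL below is an assumption; adjust host/port to your setup.
import base64

import requests

payload = {
    "prompt": "a watercolor lighthouse at dusk",
    "size": "1024x1024",
    "negative_prompt": "blurry, low quality",
    "steps": 9,
    "seed": 42,
    "batch_count": 1,
    "response_format": "b64_json",
}

r = requests.post("http://127.0.0.1:5000/v1/images/generations", json=payload)
r.raise_for_status()

# Each entry in 'data' carries a base64-encoded PNG.
for i, item in enumerate(r.json()["data"]):
    with open(f"image_{i}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))

Note that with response_format='url' this endpoint returns a data: URI rather than a fetchable URL, so b64_json is the more robust choice for scripts.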

extensions/openai/script.py

@@ -7,26 +7,23 @@ import traceback
 from collections import deque
 from threading import Thread

-import extensions.openai.completions as OAIcompletions
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
 import uvicorn
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
+
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.images as OAIimages
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
+from extensions.openai.errors import ServiceUnavailableError
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool

 from .typing import (
     ChatCompletionRequest,

@@ -40,6 +37,8 @@ from .typing import (
     EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
+    ImageGenerationRequest,
+    ImageGenerationResponse,
     LoadLorasRequest,
     LoadModelRequest,
     LogitsRequest,
@@ -228,19 +227,24 @@ async def handle_audio_transcription(request: Request):
     return JSONResponse(content=transcription)


-@app.post('/v1/images/generations', dependencies=check_key)
-async def handle_image_generation(request: Request):
-    import extensions.openai.images as OAIimages
-
-    if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
-        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
-
-    body = await request.json()
-    prompt = body['prompt']
-    size = body.get('size', '1024x1024')
-    response_format = body.get('response_format', 'url')  # or b64_json
-    n = body.get('n', 1)  # ignore the batch limits of max 10
-
-    response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
+@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
+async def handle_image_generation(request_data: ImageGenerationRequest):
+    width, height = request_data.get_width_height()
+
+    response = await asyncio.to_thread(
+        OAIimages.generations,
+        prompt=request_data.prompt,
+        size=f"{width}x{height}",
+        response_format=request_data.response_format,
+        n=request_data.batch_size,  # use resolved batch_size
+        negative_prompt=request_data.negative_prompt,
+        steps=request_data.steps,
+        seed=request_data.seed,
+        cfg_scale=request_data.cfg_scale,
+        batch_count=request_data.batch_count,
+    )

     return JSONResponse(response)
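
Because the request stays OpenAI-shaped, the official openai Python client should also be able to drive the endpoint; a sketch under the same local-server assumption, with the extended fields passed through extra_body:

# Sketch using the openai client against the local server.
# base_url and api_key are assumptions; extended, non-OpenAI fields
# are forwarded in the request body via extra_body.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="unused")

result = client.images.generate(
    prompt="an isometric pixel-art castle",
    size="1024x1024",
    response_format="b64_json",
    n=1,
    extra_body={"negative_prompt": "text, watermark", "steps": 9, "seed": -1},
)
print(len(result.data[0].b64_json))  # length of the base64 PNG payload

On the server side, asyncio.to_thread moves the blocking diffusion call off the event loop, so other endpoints remain responsive while a batch renders.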

extensions/openai/typing.py

@@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel):
     lora_names: List[str]


+class ImageGenerationRequest(BaseModel):
+    """OpenAI-compatible image generation request with extended parameters."""
+    # Required
+    prompt: str
+
+    # Generation parameters
+    negative_prompt: str = ""
+    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
+    steps: int = Field(default=9, ge=1)
+    cfg_scale: float = Field(default=0.0, ge=0.0)
+    seed: int = Field(default=-1, description="-1 for random")
+    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
+    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
+    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
+
+    # OpenAI compatibility (unused)
+    model: str | None = None
+    response_format: str = "b64_json"
+    user: str | None = None
+
+    @model_validator(mode='after')
+    def resolve_batch_size(self):
+        """Use batch_size if provided, otherwise fall back to n."""
+        if self.batch_size is None:
+            self.batch_size = self.n
+        return self
+
+    def get_width_height(self) -> tuple[int, int]:
+        try:
+            parts = self.size.lower().split('x')
+            return int(parts[0]), int(parts[1])
+        except (ValueError, IndexError):
+            return 1024, 1024
+
+
+class ImageGenerationResponse(BaseModel):
+    created: int = int(time.time())
+    data: List[dict]
+
+
 def to_json(obj):
     return json.dumps(obj.__dict__, indent=4)
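
A quick sketch of how the new request model resolves its fields; the import path assumes the repository root is on sys.path:

# Demonstrates the n/batch_size validator and the size parser.
# The import path below is an assumption about where this runs.
from extensions.openai.typing import ImageGenerationRequest

req = ImageGenerationRequest(prompt="test", n=4)
print(req.batch_size)  # 4 -- the validator falls back to the OpenAI-style 'n'

req = ImageGenerationRequest(prompt="test", n=4, batch_size=2)
print(req.batch_size)  # 2 -- an explicit batch_size wins

print(ImageGenerationRequest(prompt="test", size="640X480").get_width_height())  # (640, 480)
print(ImageGenerationRequest(prompt="test", size="bogus").get_width_height())    # (1024, 1024) fallback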