Mirror of https://github.com/oobabooga/text-generation-webui.git
Add an API endpoint for generating images
This commit is contained in:
parent
9448bf1caa
commit
5433ef3333
@@ -1,70 +1,154 @@
+"""
+OpenAI-compatible image generation using local diffusion models.
+"""
+
 import base64
+import io
+import json
 import os
 import time
+from datetime import datetime
 
-import requests
+import numpy as np
 from extensions.openai.errors import ServiceUnavailableError
+from modules import shared
+from modules.logging_colors import logger
+from PIL.PngImagePlugin import PngInfo
 
 
-def generations(prompt: str, size: str, response_format: str, n: int):
-    # Stable Diffusion callout wrapper for txt2img
-    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
-    # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
-    # If you want high quality tailored results you should just use the Stable Diffusion API directly.
-    # it's too general an API to try and shape the result with specific tags like negative prompts
-    # or "masterpiece", etc. SD configuration is beyond the scope of this API.
-    # At this point I will not add the edits and variations endpoints (ie. img2img) because they
-    # require changing the form data handling to accept multipart form data, also to properly support
-    # url return types will require file management and a web serving files... Perhaps later!
-    base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
-    sd_defaults = {
-        'sampler_name': 'DPM++ 2M Karras',  # vast improvement
-        'steps': 30,
-    }
+def generations(prompt: str, size: str, response_format: str, n: int,
+                negative_prompt: str = "", steps: int = 9, seed: int = -1,
+                cfg_scale: float = 0.0, batch_count: int = 1):
+    """
+    Generate images using the loaded diffusion model.
+
+    Args:
+        prompt: Text description of the desired image
+        size: Image dimensions as "WIDTHxHEIGHT"
+        response_format: 'url' or 'b64_json'
+        n: Number of images per batch
+        negative_prompt: What to avoid in the image
+        steps: Number of inference steps
+        seed: Random seed (-1 for random)
+        cfg_scale: Classifier-free guidance scale
+        batch_count: Number of sequential batches
+
+    Returns:
+        dict with 'created' timestamp and 'data' list of images
+    """
+    import torch
+
+    from modules.image_models import get_pipeline_type
+    from modules.torch_utils import clear_torch_cache, get_device
+
+    if shared.image_model is None:
+        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
+
+    clear_torch_cache()
+
+    # Parse dimensions
+    try:
+        width, height = [int(x) for x in size.split('x')]
+    except (ValueError, IndexError):
+        width, height = 1024, 1024
+
+    # Handle seed
+    if seed == -1:
+        seed = np.random.randint(0, 2**32 - 1)
+
+    device = get_device() or "cpu"
+    generator = torch.Generator(device).manual_seed(int(seed))
+
+    # Get pipeline type for CFG parameter name
+    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs
+    gen_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "height": height,
+        "width": width,
+        "num_inference_steps": steps,
+        "num_images_per_prompt": n,
+        "generator": generator,
+    }
 
-    width, height = [int(x) for x in size.split('x')]  # ignore the restrictions on size
+    # Pipeline-specific CFG parameter
+    if pipeline_type == 'qwenimage':
+        gen_kwargs["true_cfg_scale"] = cfg_scale
+    else:
+        gen_kwargs["guidance_scale"] = cfg_scale
 
-    # to hack on better generation, edit default payload.
-    payload = {
-        'prompt': prompt,  # ignore prompt limit of 1000 characters
-        'width': width,
-        'height': height,
-        'batch_size': n,
-    }
-    payload.update(sd_defaults)
+    # Generate
+    all_images = []
+    t0 = time.time()
 
-    scale = min(width, height) / base_model_size
-    if scale >= 1.2:
-        # for better performance with the default size (1024), and larger res.
-        scaler = {
-            'width': width // scale,
-            'height': height // scale,
-            'hr_scale': scale,
-            'enable_hr': True,
-            'hr_upscaler': 'Latent',
-            'denoising_strength': 0.68,
-        }
-        payload.update(scaler)
+    shared.stop_everything = False
+
+    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+        if shared.stop_everything:
+            pipe._interrupt = True
+        return callback_kwargs
+
+    gen_kwargs["callback_on_step_end"] = interrupt_callback
+
+    for i in range(batch_count):
+        if shared.stop_everything:
+            break
+        generator.manual_seed(int(seed + i))
+        batch_results = shared.image_model(**gen_kwargs).images
+        all_images.extend(batch_results)
+
+    t1 = time.time()
+    total_images = len(all_images)
+    total_steps = steps * batch_count
+    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
+
+    # Save images
+    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
 
+    # Build response
     resp = {
         'created': int(time.time()),
         'data': []
     }
-    from extensions.openai.script import params
 
-    # TODO: support SD_WEBUI_AUTH username:password pair.
-    sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
-
-    response = requests.post(url=sd_url, json=payload)
-    r = response.json()
-    if response.status_code != 200 or 'images' not in r:
-        print(r)
-        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
-    # r['parameters']...
-    for b64_json in r['images']:
+    for img in all_images:
+        b64 = _image_to_base64(img)
         if response_format == 'b64_json':
-            resp['data'].extend([{'b64_json': b64_json}])
+            resp['data'].append({'b64_json': b64})
         else:
-            resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # yeah it's lazy. requests.get() will not work with this
+            resp['data'].append({'url': f'data:image/png;base64,{b64}'})
 
     return resp
+
+
+def _image_to_base64(image) -> str:
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
+    """Save images with metadata."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder, exist_ok=True)
+
+    metadata = {
+        'image_prompt': prompt,
+        'image_neg_prompt': negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': steps,
+        'image_seed': seed,
+        'image_cfg_scale': cfg_scale,
+        'model': getattr(shared, 'image_model_name', 'unknown'),
+    }
+
+    for idx, img in enumerate(images):
+        ts = datetime.now().strftime("%H-%M-%S")
+        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
+
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", json.dumps(metadata))
+        img.save(filepath, pnginfo=png_info)
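Since the rewritten image module (presumably extensions/openai/images.py, given the `import extensions.openai.images as OAIimages` statements elsewhere in this commit) embeds the generation settings in each saved PNG as a text chunk, those settings can be read back later. A minimal sketch using Pillow follows; the file path is a hypothetical example of an output written under user_data/image_outputs/, not a real file from the commit.

    import json
    from PIL import Image

    # Hypothetical path; _save_images names files as "<HH-MM-SS>_<seed>_<index>.png"
    img = Image.open("user_data/image_outputs/2025-01-01/12-00-00_0000000042_000.png")
    # The tEXt chunk written by PngInfo.add_text shows up in the image's info dict
    settings = json.loads(img.info["image_gen_settings"])
    print(settings["image_prompt"], settings["image_seed"])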
@@ -7,26 +7,23 @@ import traceback
 from collections import deque
 from threading import Thread
 
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
 import uvicorn
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool
 
-import extensions.openai.completions as OAIcompletions
-import extensions.openai.images as OAIimages
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
-from extensions.openai.errors import ServiceUnavailableError
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
-
 from .typing import (
     ChatCompletionRequest,
@@ -40,6 +37,8 @@
     EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
+    ImageGenerationRequest,
+    ImageGenerationResponse,
     LoadLorasRequest,
     LoadModelRequest,
     LogitsRequest,
@@ -228,19 +227,24 @@ async def handle_audio_transcription(request: Request):
     return JSONResponse(content=transcription)
 
 
-@app.post('/v1/images/generations', dependencies=check_key)
-async def handle_image_generation(request: Request):
+@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
+async def handle_image_generation(request_data: ImageGenerationRequest):
+    import extensions.openai.images as OAIimages
 
-    if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
-        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
+    width, height = request_data.get_width_height()
 
-    body = await request.json()
-    prompt = body['prompt']
-    size = body.get('size', '1024x1024')
-    response_format = body.get('response_format', 'url')  # or b64_json
-    n = body.get('n', 1)  # ignore the batch limits of max 10
-
-    response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
+    response = await asyncio.to_thread(
+        OAIimages.generations,
+        prompt=request_data.prompt,
+        size=f"{width}x{height}",
+        response_format=request_data.response_format,
+        n=request_data.batch_size,  # <-- use resolved batch_size
+        negative_prompt=request_data.negative_prompt,
+        steps=request_data.steps,
+        seed=request_data.seed,
+        cfg_scale=request_data.cfg_scale,
+        batch_count=request_data.batch_count,
+    )
     return JSONResponse(response)
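For reference, a client call against the new endpoint might look like the sketch below. The host, port, and Authorization header are assumptions about a typical local deployment (5000 is the project's default API port, and the bearer key is only checked when one is configured); the request fields themselves come from the handler and request model in this commit.

    import base64
    import requests

    # Hypothetical local server; adjust host, port, and key to your own setup
    url = "http://127.0.0.1:5000/v1/images/generations"
    headers = {"Authorization": "Bearer <your-api-key>"}  # only needed if an API key is set

    payload = {
        "prompt": "a watercolor lighthouse at dusk",
        "size": "1024x1024",
        "steps": 9,
        "cfg_scale": 0.0,
        "seed": -1,
        "response_format": "b64_json",
        "n": 1,
    }

    r = requests.post(url, json=payload, headers=headers, timeout=600)
    r.raise_for_status()
    for i, item in enumerate(r.json()["data"]):
        # Each entry holds a base64-encoded PNG when response_format is "b64_json"
        with open(f"image_{i}.png", "wb") as f:
            f.write(base64.b64decode(item["b64_json"]))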
@@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel):
     lora_names: List[str]
 
 
+class ImageGenerationRequest(BaseModel):
+    """OpenAI-compatible image generation request with extended parameters."""
+    # Required
+    prompt: str
+
+    # Generation parameters
+    negative_prompt: str = ""
+    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
+    steps: int = Field(default=9, ge=1)
+    cfg_scale: float = Field(default=0.0, ge=0.0)
+    seed: int = Field(default=-1, description="-1 for random")
+    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
+    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
+    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
+
+    # OpenAI compatibility (unused)
+    model: str | None = None
+    response_format: str = "b64_json"
+    user: str | None = None
+
+    @model_validator(mode='after')
+    def resolve_batch_size(self):
+        """Use batch_size if provided, otherwise fall back to n."""
+        if self.batch_size is None:
+            self.batch_size = self.n
+        return self
+
+    def get_width_height(self) -> tuple[int, int]:
+        try:
+            parts = self.size.lower().split('x')
+            return int(parts[0]), int(parts[1])
+        except (ValueError, IndexError):
+            return 1024, 1024
+
+
+class ImageGenerationResponse(BaseModel):
+    created: int = int(time.time())
+    data: List[dict]
+
+
 def to_json(obj):
     return json.dumps(obj.__dict__, indent=4)
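A minimal sketch of how the request model resolves its fields, assuming the class above is importable as extensions.openai.typing (the module script.py pulls .typing from); the prompt and size values are arbitrary examples.

    from extensions.openai.typing import ImageGenerationRequest

    req = ImageGenerationRequest(prompt="a red bicycle", size="768x512", n=2)
    assert req.batch_size == 2                   # falls back to n when batch_size is omitted
    assert req.get_width_height() == (768, 512)

    req = ImageGenerationRequest(prompt="a red bicycle", size="not-a-size", batch_size=4, n=1)
    assert req.batch_size == 4                   # explicit batch_size wins over n
    assert req.get_width_height() == (1024, 1024)  # malformed size falls back to the default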