Add an API endpoint for generating images
This commit is contained in:
parent 9448bf1caa
commit 5433ef3333
extensions/openai/images.py

@@ -1,70 +1,154 @@
+"""
+OpenAI-compatible image generation using local diffusion models.
+"""
+
+import base64
+import io
+import json
 import os
 import time
+from datetime import datetime
 
-import requests
+import numpy as np
 
 from extensions.openai.errors import ServiceUnavailableError
+from modules import shared
+from modules.logging_colors import logger
+from PIL.PngImagePlugin import PngInfo
 
 
-def generations(prompt: str, size: str, response_format: str, n: int):
-    # Stable Diffusion callout wrapper for txt2img
-    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
-    # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
-    # If you want high quality tailored results you should just use the Stable Diffusion API directly.
-    # it's too general an API to try and shape the result with specific tags like negative prompts
-    # or "masterpiece", etc. SD configuration is beyond the scope of this API.
-    # At this point I will not add the edits and variations endpoints (ie. img2img) because they
-    # require changing the form data handling to accept multipart form data, also to properly support
-    # url return types will require file management and a web serving files... Perhaps later!
-
-    base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
-    sd_defaults = {
-        'sampler_name': 'DPM++ 2M Karras',  # vast improvement
-        'steps': 30,
-    }
-
-    width, height = [int(x) for x in size.split('x')]  # ignore the restrictions on size
-
-    # to hack on better generation, edit default payload.
-    payload = {
-        'prompt': prompt,  # ignore prompt limit of 1000 characters
-        'width': width,
-        'height': height,
-        'batch_size': n,
-    }
-    payload.update(sd_defaults)
-
-    scale = min(width, height) / base_model_size
-    if scale >= 1.2:
-        # for better performance with the default size (1024), and larger res.
-        scaler = {
-            'width': width // scale,
-            'height': height // scale,
-            'hr_scale': scale,
-            'enable_hr': True,
-            'hr_upscaler': 'Latent',
-            'denoising_strength': 0.68,
-        }
-        payload.update(scaler)
+def generations(prompt: str, size: str, response_format: str, n: int,
+                negative_prompt: str = "", steps: int = 9, seed: int = -1,
+                cfg_scale: float = 0.0, batch_count: int = 1):
+    """
+    Generate images using the loaded diffusion model.
+
+    Args:
+        prompt: Text description of the desired image
+        size: Image dimensions as "WIDTHxHEIGHT"
+        response_format: 'url' or 'b64_json'
+        n: Number of images per batch
+        negative_prompt: What to avoid in the image
+        steps: Number of inference steps
+        seed: Random seed (-1 for random)
+        cfg_scale: Classifier-free guidance scale
+        batch_count: Number of sequential batches
+
+    Returns:
+        dict with 'created' timestamp and 'data' list of images
+    """
+    import torch
+
+    from modules.image_models import get_pipeline_type
+    from modules.torch_utils import clear_torch_cache, get_device
+
+    if shared.image_model is None:
+        raise ServiceUnavailableError("No image model loaded. Load a model via the UI first.")
+
+    clear_torch_cache()
+
+    # Parse dimensions
+    try:
+        width, height = [int(x) for x in size.split('x')]
+    except (ValueError, IndexError):
+        width, height = 1024, 1024
+
+    # Handle seed
+    if seed == -1:
+        seed = np.random.randint(0, 2**32 - 1)
+
+    device = get_device() or "cpu"
+    generator = torch.Generator(device).manual_seed(int(seed))
+
+    # Get pipeline type for CFG parameter name
+    pipeline_type = getattr(shared, 'image_pipeline_type', None) or get_pipeline_type(shared.image_model)
+
+    # Build generation kwargs
+    gen_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "height": height,
+        "width": width,
+        "num_inference_steps": steps,
+        "num_images_per_prompt": n,
+        "generator": generator,
+    }
+
+    # Pipeline-specific CFG parameter
+    if pipeline_type == 'qwenimage':
+        gen_kwargs["true_cfg_scale"] = cfg_scale
+    else:
+        gen_kwargs["guidance_scale"] = cfg_scale
+
+    # Generate
+    all_images = []
+    t0 = time.time()
+
+    shared.stop_everything = False
+
+    def interrupt_callback(pipe, step_index, timestep, callback_kwargs):
+        if shared.stop_everything:
+            pipe._interrupt = True
+        return callback_kwargs
+
+    gen_kwargs["callback_on_step_end"] = interrupt_callback
+
+    for i in range(batch_count):
+        if shared.stop_everything:
+            break
+        generator.manual_seed(int(seed + i))
+        batch_results = shared.image_model(**gen_kwargs).images
+        all_images.extend(batch_results)
+
+    t1 = time.time()
+    total_images = len(all_images)
+    total_steps = steps * batch_count
+    logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
+
+    # Save images
+    _save_images(all_images, prompt, negative_prompt, width, height, steps, seed, cfg_scale)
 
+    # Build response
     resp = {
         'created': int(time.time()),
         'data': []
     }
-    from extensions.openai.script import params
 
-    # TODO: support SD_WEBUI_AUTH username:password pair.
-    sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
-
-    response = requests.post(url=sd_url, json=payload)
-    r = response.json()
-    if response.status_code != 200 or 'images' not in r:
-        print(r)
-        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
-    # r['parameters']...
-    for b64_json in r['images']:
+    for img in all_images:
+        b64 = _image_to_base64(img)
         if response_format == 'b64_json':
-            resp['data'].extend([{'b64_json': b64_json}])
+            resp['data'].append({'b64_json': b64})
         else:
-            resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # yeah it's lazy. requests.get() will not work with this
+            resp['data'].append({'url': f'data:image/png;base64,{b64}'})
 
     return resp
+
+
+def _image_to_base64(image) -> str:
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+def _save_images(images, prompt, negative_prompt, width, height, steps, seed, cfg_scale):
+    """Save images with metadata."""
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    folder = os.path.join("user_data", "image_outputs", date_str)
+    os.makedirs(folder, exist_ok=True)
+
+    metadata = {
+        'image_prompt': prompt,
+        'image_neg_prompt': negative_prompt,
+        'image_width': width,
+        'image_height': height,
+        'image_steps': steps,
+        'image_seed': seed,
+        'image_cfg_scale': cfg_scale,
+        'model': getattr(shared, 'image_model_name', 'unknown'),
+    }
+
+    for idx, img in enumerate(images):
+        ts = datetime.now().strftime("%H-%M-%S")
+        filepath = os.path.join(folder, f"{ts}_{seed:010d}_{idx:03d}.png")
+
+        png_info = PngInfo()
+        png_info.add_text("image_gen_settings", json.dumps(metadata))
+        img.save(filepath, pnginfo=png_info)
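The rewritten module returns base64-encoded PNGs in an OpenAI-shaped response body, so a plain HTTP client is enough to exercise the endpoint. A minimal client sketch, assuming the OpenAI-compatible API is enabled and listening on its default port 5000; the prompt, seed, and output filenames are illustrative:

import base64

import requests

# Hypothetical request against a local server with an image model loaded.
payload = {
    "prompt": "a watercolor fox in a snowy forest",
    "size": "1024x1024",
    "steps": 9,
    "seed": 42,
    "response_format": "b64_json",
}
r = requests.post("http://127.0.0.1:5000/v1/images/generations", json=payload)
r.raise_for_status()

# Each data entry carries the PNG bytes as base64, mirroring OpenAI's schema.
for i, item in enumerate(r.json()["data"]):
    with open(f"output_{i}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))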
extensions/openai/script.py

@@ -7,26 +7,23 @@ import traceback
 from collections import deque
 from threading import Thread
 
+import extensions.openai.completions as OAIcompletions
+import extensions.openai.logits as OAIlogits
+import extensions.openai.models as OAImodels
 import uvicorn
+from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.utils import _start_cloudflared
 from fastapi import Depends, FastAPI, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.requests import Request
 from fastapi.responses import JSONResponse
-from pydub import AudioSegment
-from sse_starlette import EventSourceResponse
-from starlette.concurrency import iterate_in_threadpool
-
-import extensions.openai.completions as OAIcompletions
-import extensions.openai.images as OAIimages
-import extensions.openai.logits as OAIlogits
-import extensions.openai.models as OAImodels
-from extensions.openai.errors import ServiceUnavailableError
-from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
+from pydub import AudioSegment
+from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool
 
 from .typing import (
     ChatCompletionRequest,
@@ -40,6 +37,8 @@ from .typing import (
     EmbeddingsResponse,
     EncodeRequest,
     EncodeResponse,
+    ImageGenerationRequest,
+    ImageGenerationResponse,
     LoadLorasRequest,
     LoadModelRequest,
     LogitsRequest,
@@ -228,19 +227,24 @@ async def handle_audio_transcription(request: Request):
     return JSONResponse(content=transcription)
 
 
-@app.post('/v1/images/generations', dependencies=check_key)
-async def handle_image_generation(request: Request):
-
-    if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
-        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
-
-    body = await request.json()
-    prompt = body['prompt']
-    size = body.get('size', '1024x1024')
-    response_format = body.get('response_format', 'url')  # or b64_json
-    n = body.get('n', 1)  # ignore the batch limits of max 10
-
-    response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
+@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
+async def handle_image_generation(request_data: ImageGenerationRequest):
+    import extensions.openai.images as OAIimages
+
+    width, height = request_data.get_width_height()
+
+    response = await asyncio.to_thread(
+        OAIimages.generations,
+        prompt=request_data.prompt,
+        size=f"{width}x{height}",
+        response_format=request_data.response_format,
+        n=request_data.batch_size,  # <-- use resolved batch_size
+        negative_prompt=request_data.negative_prompt,
+        steps=request_data.steps,
+        seed=request_data.seed,
+        cfg_scale=request_data.cfg_scale,
+        batch_count=request_data.batch_count,
+    )
     return JSONResponse(response)
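OAIimages.generations is a blocking, GPU-bound call, so the handler dispatches it with asyncio.to_thread to keep the FastAPI event loop free for other requests while an image renders. A self-contained sketch of that pattern, with a hypothetical slow_render standing in for the diffusion pipeline:

import asyncio
import time


def slow_render(prompt: str) -> str:
    # Stand-in for a multi-second, blocking diffusion run.
    time.sleep(2)
    return f"rendered: {prompt}"


async def handler() -> str:
    # Same shape as the endpoint's asyncio.to_thread(OAIimages.generations, ...):
    # the blocking call runs on a worker thread; the coroutine awaits its result.
    return await asyncio.to_thread(slow_render, "a red panda")


print(asyncio.run(handler()))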
extensions/openai/typing.py

@@ -264,6 +264,46 @@ class LoadLorasRequest(BaseModel):
     lora_names: List[str]
 
 
+class ImageGenerationRequest(BaseModel):
+    """OpenAI-compatible image generation request with extended parameters."""
+    # Required
+    prompt: str
+
+    # Generation parameters
+    negative_prompt: str = ""
+    size: str = Field(default="1024x1024", description="'WIDTHxHEIGHT'")
+    steps: int = Field(default=9, ge=1)
+    cfg_scale: float = Field(default=0.0, ge=0.0)
+    seed: int = Field(default=-1, description="-1 for random")
+    batch_size: int | None = Field(default=None, ge=1, description="Parallel batch size (VRAM heavy)")
+    n: int = Field(default=1, ge=1, description="Alias for batch_size (OpenAI compatibility)")
+    batch_count: int = Field(default=1, ge=1, description="Sequential batch count")
+
+    # OpenAI compatibility (unused)
+    model: str | None = None
+    response_format: str = "b64_json"
+    user: str | None = None
+
+    @model_validator(mode='after')
+    def resolve_batch_size(self):
+        """Use batch_size if provided, otherwise fall back to n."""
+        if self.batch_size is None:
+            self.batch_size = self.n
+        return self
+
+    def get_width_height(self) -> tuple[int, int]:
+        try:
+            parts = self.size.lower().split('x')
+            return int(parts[0]), int(parts[1])
+        except (ValueError, IndexError):
+            return 1024, 1024
+
+
+class ImageGenerationResponse(BaseModel):
+    created: int = int(time.time())
+    data: List[dict]
+
+
 def to_json(obj):
     return json.dumps(obj.__dict__, indent=4)
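The resolve_batch_size validator is what makes n a working alias for batch_size. A standalone sketch of just those two fields, assuming pydantic v2 (which provides model_validator):

from pydantic import BaseModel, Field, model_validator


class Req(BaseModel):
    # Re-creation of the two interacting fields from ImageGenerationRequest.
    batch_size: int | None = Field(default=None, ge=1)
    n: int = Field(default=1, ge=1)

    @model_validator(mode='after')
    def resolve_batch_size(self):
        if self.batch_size is None:
            self.batch_size = self.n
        return self


print(Req(n=4).batch_size)                # 4: falls back to the OpenAI-style alias
print(Req(n=4, batch_size=2).batch_size)  # 2: an explicit batch_size wins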