Add semaphore to image generation API endpoint

This commit is contained in:
oobabooga 2025-12-03 12:02:47 -08:00
parent 5ad174fad2
commit 4468c49439

View file

@ -53,12 +53,12 @@ from .typing import (
params = {
'embedding_device': 'cpu',
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
'sd_webui_url': '',
'debug': 0
}
streaming_semaphore = asyncio.Semaphore(1)
image_generation_semaphore = asyncio.Semaphore(1)
def verify_api_key(authorization: str = Header(None)) -> None:
@ -231,21 +231,22 @@ async def handle_audio_transcription(request: Request):
async def handle_image_generation(request_data: ImageGenerationRequest):
import extensions.openai.images as OAIimages
width, height = request_data.get_width_height()
async with image_generation_semaphore:
width, height = request_data.get_width_height()
response = await asyncio.to_thread(
OAIimages.generations,
prompt=request_data.prompt,
size=f"{width}x{height}",
response_format=request_data.response_format,
n=request_data.batch_size, # <-- use resolved batch_size
negative_prompt=request_data.negative_prompt,
steps=request_data.steps,
seed=request_data.seed,
cfg_scale=request_data.cfg_scale,
batch_count=request_data.batch_count,
)
return JSONResponse(response)
response = await asyncio.to_thread(
OAIimages.generations,
prompt=request_data.prompt,
size=f"{width}x{height}",
response_format=request_data.response_format,
n=request_data.batch_size, # <-- use resolved batch_size
negative_prompt=request_data.negative_prompt,
steps=request_data.steps,
seed=request_data.seed,
cfg_scale=request_data.cfg_scale,
batch_count=request_data.batch_count,
)
return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)