From 4468c49439685dc8bc68e9d7a6109694a2eab72b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 3 Dec 2025 12:02:47 -0800 Subject: [PATCH] Add semaphore to image generation API endpoint --- extensions/openai/script.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 1e982731..65805629 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -53,12 +53,12 @@ from .typing import ( params = { 'embedding_device': 'cpu', 'embedding_model': 'sentence-transformers/all-mpnet-base-v2', - 'sd_webui_url': '', 'debug': 0 } streaming_semaphore = asyncio.Semaphore(1) +image_generation_semaphore = asyncio.Semaphore(1) def verify_api_key(authorization: str = Header(None)) -> None: @@ -231,21 +231,22 @@ async def handle_audio_transcription(request: Request): async def handle_image_generation(request_data: ImageGenerationRequest): import extensions.openai.images as OAIimages - width, height = request_data.get_width_height() + async with image_generation_semaphore: + width, height = request_data.get_width_height() - response = await asyncio.to_thread( - OAIimages.generations, - prompt=request_data.prompt, - size=f"{width}x{height}", - response_format=request_data.response_format, - n=request_data.batch_size, # <-- use resolved batch_size - negative_prompt=request_data.negative_prompt, - steps=request_data.steps, - seed=request_data.seed, - cfg_scale=request_data.cfg_scale, - batch_count=request_data.batch_count, - ) - return JSONResponse(response) + response = await asyncio.to_thread( + OAIimages.generations, + prompt=request_data.prompt, + size=f"{width}x{height}", + response_format=request_data.response_format, + n=request_data.batch_size, # <-- use resolved batch_size + negative_prompt=request_data.negative_prompt, + steps=request_data.steps, + seed=request_data.seed, + cfg_scale=request_data.cfg_scale, + batch_count=request_data.batch_count, + ) + return JSONResponse(response) @app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)