mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-01-31 12:54:33 +01:00
Add semaphore to image generation API endpoint
This commit is contained in:
parent
5ad174fad2
commit
4468c49439
|
|
@ -53,12 +53,12 @@ from .typing import (
|
|||
params = {
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||
'sd_webui_url': '',
|
||||
'debug': 0
|
||||
}
|
||||
|
||||
|
||||
streaming_semaphore = asyncio.Semaphore(1)
|
||||
image_generation_semaphore = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def verify_api_key(authorization: str = Header(None)) -> None:
|
||||
|
|
@ -231,21 +231,22 @@ async def handle_audio_transcription(request: Request):
|
|||
async def handle_image_generation(request_data: ImageGenerationRequest):
|
||||
import extensions.openai.images as OAIimages
|
||||
|
||||
width, height = request_data.get_width_height()
|
||||
async with image_generation_semaphore:
|
||||
width, height = request_data.get_width_height()
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
OAIimages.generations,
|
||||
prompt=request_data.prompt,
|
||||
size=f"{width}x{height}",
|
||||
response_format=request_data.response_format,
|
||||
n=request_data.batch_size, # <-- use resolved batch_size
|
||||
negative_prompt=request_data.negative_prompt,
|
||||
steps=request_data.steps,
|
||||
seed=request_data.seed,
|
||||
cfg_scale=request_data.cfg_scale,
|
||||
batch_count=request_data.batch_count,
|
||||
)
|
||||
return JSONResponse(response)
|
||||
response = await asyncio.to_thread(
|
||||
OAIimages.generations,
|
||||
prompt=request_data.prompt,
|
||||
size=f"{width}x{height}",
|
||||
response_format=request_data.response_format,
|
||||
n=request_data.batch_size, # <-- use resolved batch_size
|
||||
negative_prompt=request_data.negative_prompt,
|
||||
steps=request_data.steps,
|
||||
seed=request_data.seed,
|
||||
cfg_scale=request_data.cfg_scale,
|
||||
batch_count=request_data.batch_count,
|
||||
)
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
|
||||
|
|
|
|||
Loading…
Reference in a new issue