API: Stop generation on client disconnect for non-streaming requests

This commit is contained in:
oobabooga 2026-03-07 06:06:13 -08:00
parent 304510eb3d
commit 328215b0c7

View file

@ -59,6 +59,15 @@ params = {
}
async def _wait_for_disconnect(request: Request, stop_event: threading.Event):
"""Block until the client disconnects, then signal the stop_event."""
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
stop_event.set()
return
def verify_api_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.api_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@ -129,12 +138,17 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
else:
stop_event = threading.Event()
response = await asyncio.to_thread(
OAIcompletions.completions,
to_dict(request_data),
is_legacy=is_legacy,
stop_event=stop_event
)
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
try:
response = await asyncio.to_thread(
OAIcompletions.completions,
to_dict(request_data),
is_legacy=is_legacy,
stop_event=stop_event
)
finally:
stop_event.set()
monitor.cancel()
return JSONResponse(response)
@ -164,12 +178,17 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
else:
stop_event = threading.Event()
response = await asyncio.to_thread(
OAIcompletions.chat_completions,
to_dict(request_data),
is_legacy=is_legacy,
stop_event=stop_event
)
monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
try:
response = await asyncio.to_thread(
OAIcompletions.chat_completions,
to_dict(request_data),
is_legacy=is_legacy,
stop_event=stop_event
)
finally:
stop_event.set()
monitor.cancel()
return JSONResponse(response)