From 328215b0c74d249d153e7e684c0241a4c89eee6e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 7 Mar 2026 06:06:13 -0800 Subject: [PATCH] API: Stop generation on client disconnect for non-streaming requests --- extensions/openai/script.py | 43 ++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index bfb6fd54..7a13638d 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -59,6 +59,15 @@ params = { } +async def _wait_for_disconnect(request: Request, stop_event: threading.Event): + """Block until the client disconnects, then signal the stop_event.""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + stop_event.set() + return + + def verify_api_key(authorization: str = Header(None)) -> None: expected_api_key = shared.args.api_key if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"): @@ -129,12 +138,17 @@ async def openai_completions(request: Request, request_data: CompletionRequest): else: stop_event = threading.Event() - response = await asyncio.to_thread( - OAIcompletions.completions, - to_dict(request_data), - is_legacy=is_legacy, - stop_event=stop_event - ) + monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event)) + try: + response = await asyncio.to_thread( + OAIcompletions.completions, + to_dict(request_data), + is_legacy=is_legacy, + stop_event=stop_event + ) + finally: + stop_event.set() + monitor.cancel() return JSONResponse(response) @@ -164,12 +178,17 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion else: stop_event = threading.Event() - response = await asyncio.to_thread( - OAIcompletions.chat_completions, - to_dict(request_data), - is_legacy=is_legacy, - stop_event=stop_event - ) + monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event)) + try: + response = await asyncio.to_thread( + OAIcompletions.chat_completions, + to_dict(request_data), + is_legacy=is_legacy, + stop_event=stop_event + ) + finally: + stop_event.set() + monitor.cancel() return JSONResponse(response)