From c375b6941395454bef52d9ac0e102c0de3f4d3ee Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 13 May 2025 11:23:33 -0700
Subject: [PATCH] API: Fix llama.cpp generating after disconnect, improve
 disconnect detection, fix deadlock on simultaneous requests

---
 extensions/openai/script.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index a995da9d..2b4f274f 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -14,6 +14,7 @@ from fastapi.requests import Request
 from fastapi.responses import JSONResponse
 from pydub import AudioSegment
 from sse_starlette import EventSourceResponse
+from starlette.concurrency import iterate_in_threadpool

 import extensions.openai.completions as OAIcompletions
 import extensions.openai.images as OAIimages
@@ -115,7 +116,7 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
         async def generator():
             async with streaming_semaphore:
                 response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-                for resp in response:
+                async for resp in iterate_in_threadpool(response):
                     disconnected = await request.is_disconnected()
                     if disconnected:
                         break
@@ -125,7 +126,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):

         return EventSourceResponse(generator())  # SSE streaming
     else:
-        response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
+        response = await asyncio.to_thread(
+            OAIcompletions.completions,
+            to_dict(request_data),
+            is_legacy=is_legacy
+        )
+
         return JSONResponse(response)
@@ -138,7 +144,7 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
         async def generator():
             async with streaming_semaphore:
                 response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-                for resp in response:
+                async for resp in iterate_in_threadpool(response):
                     disconnected = await request.is_disconnected()
                     if disconnected:
                         break
@@ -148,7 +154,12 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):

         return EventSourceResponse(generator())  # SSE streaming
     else:
-        response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
+        response = await asyncio.to_thread(
+            OAIcompletions.chat_completions,
+            to_dict(request_data),
+            is_legacy=is_legacy
+        )
+
         return JSONResponse(response)
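
Why the streaming hunks switch to `async for ... in iterate_in_threadpool(...)`: `stream_completions()` returns a regular blocking generator, and a plain `for` loop inside the async generator drives it on the event loop thread, so the `await request.is_disconnected()` check (and every other in-flight request) stalls until the model yields the next chunk. Starlette's `iterate_in_threadpool` wraps the same generator so each `next()` call runs in a worker thread. A minimal, self-contained sketch of the pattern (the `slow_tokens` generator and `heartbeat` task are illustrative stand-ins, not code from this repository):

import asyncio
import time

from starlette.concurrency import iterate_in_threadpool


def slow_tokens():
    # Stand-in for a blocking token generator such as stream_completions().
    for i in range(3):
        time.sleep(0.5)  # simulates blocking inference in llama.cpp
        yield f"token-{i}"


async def heartbeat():
    # Keeps printing while tokens are produced, showing the loop stays free.
    while True:
        print("event loop is responsive")
        await asyncio.sleep(0.2)


async def main():
    hb = asyncio.create_task(heartbeat())
    # Each next() on slow_tokens() runs in a worker thread; between chunks
    # the loop can run heartbeat(), disconnect checks, and other requests.
    async for token in iterate_in_threadpool(slow_tokens()):
        print("received", token)
    hb.cancel()


asyncio.run(main())

This is also what makes the improved disconnect detection work: `request.is_disconnected()` can only be polled between chunks if the event loop actually regains control between chunks.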
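
The non-streaming branches get the analogous fix. Calling `OAIcompletions.completions()` directly inside the `async def` endpoint blocks the event loop for the entire generation, so a second simultaneous request cannot even begin to be serviced. `asyncio.to_thread()` (Python 3.9+) runs the blocking call in a worker thread and awaits its result. A hedged sketch under the same assumption, with `blocking_completion` as a hypothetical stand-in rather than the project's real API:

import asyncio
import time


def blocking_completion(prompt: str) -> str:
    # Stand-in for a long, blocking call like OAIcompletions.completions().
    time.sleep(1.0)
    return f"completed: {prompt}"


async def handle_request(prompt: str) -> str:
    # Offload to a worker thread so the event loop keeps serving requests.
    return await asyncio.to_thread(blocking_completion, prompt)


async def main():
    # Two "simultaneous requests": offloaded, they overlap (~1s total);
    # called directly on the loop thread, they would serialize (~2s).
    start = time.perf_counter()
    results = await asyncio.gather(handle_request("a"), handle_request("b"))
    print(results, f"{time.perf_counter() - start:.2f}s")


asyncio.run(main())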