Add a /v1/internal/chat-prompt endpoint (#5879)

This commit is contained in:
oobabooga 2024-04-19 00:24:46 -03:00 committed by GitHub
parent b30bce3b2f
commit f27e1ba302
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 11 deletions

View file

@@ -3,6 +3,7 @@ import json
import logging
import os
import traceback
from collections import deque
from threading import Thread
import speech_recognition as sr
@@ -31,6 +32,7 @@ from modules.text_generation import stop_everything_event
from .typing import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatPromptResponse,
CompletionRequest,
CompletionResponse,
DecodeRequest,
@@ -259,6 +261,15 @@ async def handle_logits(request_data: LogitsRequest):
return JSONResponse(response)
@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
    """Return the fully-rendered chat prompt for *request_data* without running generation."""
    # NOTE(review): legacy routes appear to be distinguished by "/generate" in
    # the URL path — confirm against the other endpoints in this module.
    legacy_mode = "/generate" in request.url.path
    prompt_stream = OAIcompletions.chat_completions_common(
        to_dict(request_data), is_legacy=legacy_mode, prompt_only=True
    )
    # Drain the generator, keeping only its final yielded item (raises
    # IndexError if the generator yields nothing, same as the original).
    final_payload = deque(prompt_stream, maxlen=1).pop()
    return JSONResponse(final_payload)
@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
stop_everything_event()