mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 07:03:37 +00:00
API: Add Anthropic-compatible /v1/messages endpoint
This commit is contained in:
parent
f0e3997f37
commit
0216893475
3 changed files with 600 additions and 4 deletions
|
|
@@ -10,6 +10,7 @@ from threading import Thread
|
|||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
|
@@ -19,6 +20,7 @@ from starlette.concurrency import iterate_in_threadpool
|
|||
import modules.api.completions as OAIcompletions
|
||||
import modules.api.logits as OAIlogits
|
||||
import modules.api.models as OAImodels
|
||||
import modules.api.anthropic as Anthropic
|
||||
from .tokens import token_count, token_decode, token_encode
|
||||
from .errors import OpenAIError
|
||||
from .utils import _start_cloudflared
|
||||
|
|
@@ -28,6 +30,7 @@ from modules.models import unload_model
|
|||
from modules.text_generation import stop_everything_event # used by /v1/internal/stop-generation
|
||||
|
||||
from .typing import (
|
||||
AnthropicRequest,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatPromptResponse,
|
||||
|
|
@@ -74,9 +77,23 @@ def verify_admin_key(authorization: str = Header(None)) -> None:
|
|||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
def verify_anthropic_key(x_api_key: str = Header(None, alias="x-api-key")) -> None:
    """Authenticate Anthropic-style requests via the `x-api-key` header.

    Raises HTTPException(401) when an API key is configured and the header
    is missing or does not match it. No-op when no key is configured.
    """
    expected_api_key = shared.args.api_key
    if not expected_api_key:
        # No key configured: the endpoint is open.
        return

    # Covers both a missing header (None) and a wrong value.
    if x_api_key != expected_api_key:
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
class AnthropicError(Exception):
    """Error carrying the fields of the Anthropic API error envelope.

    Attributes:
        message: Human-readable description returned to the client.
        error_type: Anthropic error type string, e.g. "invalid_request_error",
            "api_error", "overloaded_error".
        status_code: HTTP status code to respond with.
    """

    def __init__(self, message: str, error_type: str = "invalid_request_error", status_code: int = 400):
        # Forward the message to Exception so str(exc) and logged tracebacks
        # are informative instead of empty.
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.status_code = status_code
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
check_key = [Depends(verify_api_key)]
|
||||
check_admin_key = [Depends(verify_admin_key)]
|
||||
check_anthropic_key = [Depends(verify_anthropic_key)]
|
||||
|
||||
# Configure CORS settings to allow all origins, methods, and headers
|
||||
app.add_middleware(
|
||||
|
|
@@ -102,6 +119,28 @@ async def openai_error_handler(request: Request, exc: OpenAIError):
|
|||
)
|
||||
|
||||
|
||||
@app.exception_handler(AnthropicError)
async def anthropic_error_handler(request: Request, exc: AnthropicError):
    """Serialize an AnthropicError into the Anthropic error envelope."""
    payload = {
        "type": "error",
        "error": {"type": exc.error_type, "message": exc.message},
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
async def validation_error_handler(request: Request, exc: RequestValidationError):
    """Report request-validation failures.

    Anthropic-compatible routes (/v1/messages) get a 400 in the Anthropic
    error envelope; every other route keeps FastAPI's default 422 shape.
    """
    if not request.url.path.startswith("/v1/messages"):
        return JSONResponse(status_code=422, content={"detail": exc.errors()})

    # Collapse each validation error into "dotted.location: message".
    parts = []
    for err in exc.errors():
        location = ".".join(str(piece) for piece in err["loc"])
        parts.append(f"{location}: {err['msg']}")
    messages = "; ".join(parts)

    return JSONResponse(
        status_code=400,
        content={"type": "error", "error": {"type": "invalid_request_error", "message": messages}}
    )
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def validate_host_header(request: Request, call_next):
|
||||
# Be strict about only approving access to localhost by default
|
||||
|
|
@@ -211,6 +250,76 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
|||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post('/v1/messages', dependencies=check_anthropic_key)
async def anthropic_messages(request: Request, request_data: AnthropicRequest):
    """Anthropic-compatible /v1/messages endpoint.

    Translates the Anthropic request into the internal OpenAI chat format,
    delegates generation to `_anthropic_generate`, and maps any failure onto
    the Anthropic error envelope via AnthropicError.
    """
    body = to_dict(request_data)
    model = body.get('model') or shared.model_name or 'unknown'

    try:
        converted = Anthropic.convert_request(body)
    except Exception as e:
        # A conversion failure is the caller's fault: 400 invalid_request_error.
        raise AnthropicError(message=str(e))

    try:
        return await _anthropic_generate(request, request_data, converted, model)
    except OpenAIError as e:
        # Map internal OpenAI-style codes onto Anthropic error types.
        if e.code == 503:
            error_type = "overloaded_error"
        elif e.code < 500:
            error_type = "invalid_request_error"
        else:
            error_type = "api_error"
        raise AnthropicError(message=e.message, error_type=error_type, status_code=e.code)
    except Exception as e:
        raise AnthropicError(message=str(e) or "Internal server error", error_type="api_error", status_code=500)
|
||||
|
||||
|
||||
async def _anthropic_generate(request, request_data, converted, model):
    """Run generation for an Anthropic /v1/messages request.

    `converted` is the request already translated into OpenAI
    chat-completions form; `model` labels the response. Returns an SSE
    stream when `request_data.stream` is set, otherwise one JSON response.
    """
    if request_data.stream:
        # Streaming path: forward OpenAI-style chunks through the Anthropic
        # stream converter as server-sent events.
        stop_event = threading.Event()

        async def generator():
            converter = Anthropic.StreamConverter(model)
            response = OAIcompletions.stream_chat_completions(converted, is_legacy=False, stop_event=stop_event)
            try:
                async for resp in iterate_in_threadpool(response):
                    # Abort generation as soon as the client goes away.
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    for event in converter.process_chunk(resp):
                        yield event

                # Flush whatever trailing events the converter still holds
                # once the upstream generator is exhausted.
                for event in converter.finish():
                    yield event
            except OpenAIError as e:
                # Map the internal error code onto an Anthropic error type and
                # emit it as an in-stream "error" SSE event (headers are
                # already sent, so an HTTP error response is not possible).
                error_type = "invalid_request_error" if e.code < 500 else "api_error"
                if e.code == 503:
                    error_type = "overloaded_error"
                yield {
                    "event": "error",
                    "data": json.dumps({"type": "error", "error": {"type": error_type, "message": e.message}})
                }
            finally:
                # Always signal the backend to stop and close the upstream
                # generator, whether we finished, errored, or disconnected.
                stop_event.set()
                response.close()

        return EventSourceResponse(generator(), sep="\n")

    else:
        # Non-streaming path: run the blocking completion in a worker thread
        # while a monitor task sets stop_event on client disconnect.
        stop_event = threading.Event()
        monitor = asyncio.create_task(_wait_for_disconnect(request, stop_event))
        try:
            openai_resp = await asyncio.to_thread(
                OAIcompletions.chat_completions,
                converted,
                is_legacy=False,
                stop_event=stop_event
            )
        finally:
            # Stop the backend and the disconnect watcher in all cases.
            stop_event.set()
            monitor.cancel()

        return JSONResponse(Anthropic.build_response(openai_resp, model))
|
||||
|
||||
|
||||
@app.get("/v1/models", dependencies=check_key)
|
||||
@app.get("/v1/models/{model}", dependencies=check_key)
|
||||
async def handle_models(request: Request):
|
||||
|
|
@@ -469,15 +578,15 @@ def run_server():
|
|||
port,
|
||||
shared.args.public_api_id,
|
||||
max_attempts=3,
|
||||
on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}/v1\n')
|
||||
on_start=lambda url: logger.info(f'API URL (OpenAI + Anthropic compatible):\n\n{url}/v1\n')
|
||||
)
|
||||
else:
|
||||
url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
|
||||
urls = [f'{url_proto}{addr}:{port}/v1' for addr in server_addrs]
|
||||
if len(urls) > 1:
|
||||
logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
|
||||
logger.info('API URLs (OpenAI + Anthropic compatible):\n\n' + '\n'.join(urls) + '\n')
|
||||
else:
|
||||
logger.info('OpenAI-compatible API URL:\n\n' + '\n'.join(urls) + '\n')
|
||||
logger.info('API URL (OpenAI + Anthropic compatible):\n\n' + '\n'.join(urls) + '\n')
|
||||
|
||||
# Log API keys
|
||||
if shared.args.api_key:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue