From 5a017aa3380ae273e0ce32fe02709f93cb710d20 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 12 Mar 2026 13:23:01 -0300
Subject: [PATCH] API: Several OpenAI spec compliance fixes

- Return proper OpenAI error format ({"error": {...}}) instead of HTTP 500 for validation errors
- Send data: [DONE] at the end of SSE streams
- Fix finish_reason so "tool_calls" takes priority over "length"
- Stop including usage in streaming chunks when include_usage is not set
- Handle "developer" role in messages (treated same as "system")
- Add logprobs and top_logprobs parameters for chat completions
- Fix chat completions logprobs not working with llama.cpp and ExLlamav3 backends
- Add max_completion_tokens as an alias for max_tokens in chat completions
---
 extensions/openai/completions.py | 25 ++++++++++++++++++-------
 extensions/openai/script.py      | 19 +++++++++++++++++++
 extensions/openai/typing.py      |  9 +++++++++
 3 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 8d3cce57..a8b899d5 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -105,6 +105,10 @@ def process_parameters(body, is_legacy=False):
         logits_processor = [LogitsBiasProcessor(logit_bias)]
 
     logprobs = body.get('logprobs', None)
+    top_logprobs = body.get('top_logprobs', None)
+    # For chat completions, logprobs is a bool; use top_logprobs for the count
+    if logprobs is True:
+        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
     if logprobs is not None and logprobs > 0:
         generate_params['logprob_proc'] = LogprobProcessor(logprobs)
         logits_processor.extend([generate_params['logprob_proc']])
@@ -191,7 +195,7 @@ def convert_history(history):
             if "tool_call_id" in entry:
                 meta["tool_call_id"] = entry["tool_call_id"]
             chat_dialogue.append(['', '', content, meta])
-        elif role == "system":
+        elif role in ("system", "developer"):
             system_message += f"\n{content}" if system_message else content
 
     if not user_input_last:
@@ -339,9 +343,13 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             }],
         }
 
-        if logprob_proc:  # not official for chat yet
+        if logprob_proc:
             top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
             chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            if backend_logprobs:
+                chunk[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
 
         return chunk
 
@@ -412,11 +420,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
     completion_token_count = len(encode(answer)[0])
 
-    stop_reason = "stop"
     if len(tool_calls) > 0:
         stop_reason = "tool_calls"
-    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
+    elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
         stop_reason = "length"
+    else:
+        stop_reason = "stop"
 
     if stream:
         chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
@@ -441,7 +450,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
                 "usage": usage
             }
         else:
-            chunk['usage'] = usage
             yield chunk
     else:
         resp = {
@@ -462,9 +470,13 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
                 "total_tokens": token_count + completion_token_count
             }
         }
 
-        if logprob_proc:  # not official for chat yet
+        if logprob_proc:
             top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
             resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            if backend_logprobs:
+                resp[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
 
         yield resp
@@ -702,7 +714,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
                 "usage": usage
             }
         else:
-            chunk["usage"] = usage
             yield chunk
 
 
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index e3726bc8..f161e1e4 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -21,6 +21,7 @@ import extensions.openai.completions as OAIcompletions
 import extensions.openai.logits as OAIlogits
 import extensions.openai.models as OAImodels
 from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.errors import OpenAIError
 from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
@@ -94,6 +95,20 @@ app.add_middleware(
 )
 
 
+@app.exception_handler(OpenAIError)
+async def openai_error_handler(request: Request, exc: OpenAIError):
+    error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
+    return JSONResponse(
+        status_code=exc.code,
+        content={"error": {
+            "message": exc.message,
+            "type": error_type,
+            "param": getattr(exc, 'param', None),
+            "code": None
+        }}
+    )
+
+
 @app.middleware("http")
 async def validate_host_header(request: Request, call_next):
     # Be strict about only approving access to localhost by default
@@ -136,6 +151,8 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
                     break
 
                 yield {"data": json.dumps(resp)}
+
+            yield {"data": "[DONE]"}
         finally:
             stop_event.set()
             response.close()
@@ -176,6 +193,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
                     break
 
                 yield {"data": json.dumps(resp)}
+
+            yield {"data": "[DONE]"}
         finally:
             stop_event.set()
             response.close()
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 4d0f4a4a..80831c44 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -152,7 +152,10 @@ class ChatCompletionRequestParams(BaseModel):
     tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
     tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
     logit_bias: dict | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
     max_tokens: int | None = None
+    max_completion_tokens: int | None = None
     n: int | None = Field(default=1, description="Unused parameter.")
     presence_penalty: float | None = shared.args.presence_penalty
     stop: str | List[str] | None = None
@@ -162,6 +165,12 @@ class ChatCompletionRequestParams(BaseModel):
     top_p: float | None = shared.args.top_p
     user: str | None = Field(default=None, description="Unused parameter.")
 
+    @model_validator(mode='after')
+    def resolve_max_tokens(self):
+        if self.max_tokens is None and self.max_completion_tokens is not None:
+            self.max_tokens = self.max_completion_tokens
+        return self
+
     mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
     instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")
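
Usage notes (illustrative sketches, not part of the diff):

A minimal client showing the new error envelope, the "developer" role, the
chat logprobs parameters, and the max_completion_tokens alias. It assumes
the API extension is listening on its default localhost:5000 and uses the
requests library; the prompt content is arbitrary:

    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/v1/chat/completions",
        json={
            "messages": [
                # "developer" is now accepted and treated like "system"
                {"role": "developer", "content": "Answer briefly."},
                {"role": "user", "content": "What is 2 + 2?"},
            ],
            # resolved into max_tokens by the new model_validator
            "max_completion_tokens": 32,
            # chat-style logprobs: a bool plus a count of alternatives
            "logprobs": True,
            "top_logprobs": 3,
        },
    )

    body = resp.json()
    if "error" in body:
        # Validation failures now return {"error": {"message", "type",
        # "param", "code"}} instead of a bare HTTP 500
        print(body["error"]["type"], body["error"]["message"])
    else:
        print(body["choices"][0]["message"]["content"])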
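Streaming clients can now rely on the data: [DONE] sentinel to detect the
end of an SSE stream. A sketch of a consumer loop, again assuming the
default endpoint:

    import json

    import requests

    with requests.post(
        "http://127.0.0.1:5000/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
        stream=True,
    ) as r:
        for line in r.iter_lines(decode_unicode=True):
            # SSE payload lines look like "data: {...}"
            if not line or not line.startswith("data: "):
                continue
            payload = line[len("data: "):]
            if payload == "[DONE]":
                break  # the server now terminates every stream this way
            chunk = json.loads(payload)
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)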
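And a quick check of how the resolve_max_tokens validator handles the new
alias (field names taken from the patch; this assumes the remaining
ChatCompletionRequestParams fields keep their defaults when omitted):

    from extensions.openai.typing import ChatCompletionRequestParams

    p = ChatCompletionRequestParams(max_completion_tokens=128)
    assert p.max_tokens == 128  # alias copied when max_tokens is unset

    p = ChatCompletionRequestParams(max_tokens=64, max_completion_tokens=128)
    assert p.max_tokens == 64   # an explicit max_tokens takes precedence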