Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-03-18)

API: Several OpenAI spec compliance fixes
- Return proper OpenAI error format ({"error": {...}}) instead of HTTP 500 for validation errors
- Send data: [DONE] at the end of SSE streams
- Fix finish_reason so "tool_calls" takes priority over "length"
- Stop including usage in streaming chunks when include_usage is not set
- Handle "developer" role in messages (treated same as "system")
- Add logprobs and top_logprobs parameters for chat completions
- Fix chat completions logprobs not working with llama.cpp and ExLlamav3 backends
- Add max_completion_tokens as an alias for max_tokens in chat completions
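
For the request-side changes listed above, here is a minimal, hedged sketch of a chat completions call. The base URL and port (http://127.0.0.1:5000), the timeout, and the exact shape of the returned logprobs field are assumptions for illustration, not taken from this diff.

# Hedged sketch: exercises the parameters touched by this commit.
# The base URL/port is an assumption; adjust to your local API listener.
import requests

payload = {
    "messages": [
        # "developer" is now accepted and folded into the system message
        {"role": "developer", "content": "Answer concisely."},
        {"role": "user", "content": "Name three prime numbers."},
    ],
    "max_completion_tokens": 64,  # alias; resolved to max_tokens on the server
    "logprobs": True,             # boolean form for chat completions
    "top_logprobs": 3,            # number of alternatives per token
}

r = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, timeout=120)
r.raise_for_status()
choice = r.json()["choices"][0]
print(choice["message"]["content"])
print(choice.get("logprobs"))
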
This commit is contained in:
parent 4b6c9db1c9
commit 5a017aa338

@@ -105,6 +105,10 @@ def process_parameters(body, is_legacy=False):
         logits_processor = [LogitsBiasProcessor(logit_bias)]

     logprobs = body.get('logprobs', None)
+    top_logprobs = body.get('top_logprobs', None)
+    # For chat completions, logprobs is a bool; use top_logprobs for the count
+    if logprobs is True:
+        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
     if logprobs is not None and logprobs > 0:
         generate_params['logprob_proc'] = LogprobProcessor(logprobs)
         logits_processor.extend([generate_params['logprob_proc']])

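The resolution rule above can be restated as a standalone, illustrative helper (not project code); the fallback of 5 alternatives mirrors the added line.

# Illustrative only: mirrors the logprobs/top_logprobs resolution logic above.
def resolve_logprob_count(logprobs, top_logprobs):
    # Chat completions send logprobs as a bool; top_logprobs carries the count.
    if logprobs is True:
        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
    return logprobs if logprobs is not None and logprobs > 0 else None

assert resolve_logprob_count(True, 3) == 3      # bool plus explicit count
assert resolve_logprob_count(True, None) == 5   # bool with no count falls back to 5
assert resolve_logprob_count(10, None) == 10    # legacy integer form passes through
assert resolve_logprob_count(False, 3) is None  # disabled
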
@@ -191,7 +195,7 @@ def convert_history(history):
             if "tool_call_id" in entry:
                 meta["tool_call_id"] = entry["tool_call_id"]
             chat_dialogue.append(['', '', content, meta])
-        elif role == "system":
+        elif role in ("system", "developer"):
             system_message += f"\n{content}" if system_message else content

     if not user_input_last:

@@ -339,9 +343,13 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             }],
         }

-        if logprob_proc: # not official for chat yet
+        if logprob_proc:
             top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
             chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+        elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+            backend_logprobs = get_logprobs_from_backend()
+            if backend_logprobs:
+                chunk[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}

         return chunk

@@ -412,11 +420,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p

     token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
     completion_token_count = len(encode(answer)[0])
-    stop_reason = "stop"
     if len(tool_calls) > 0:
         stop_reason = "tool_calls"
-    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
+    elif token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
         stop_reason = "length"
+    else:
+        stop_reason = "stop"

     if stream:
         chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)

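For clarity, the new finish_reason priority can be restated outside the project as a hedged, illustrative helper (not part of the commit), covering the three outcomes.

# Illustrative only: the finish_reason priority after this change.
def finish_reason(n_tool_calls, prompt_tokens, completion_tokens, truncation_length, max_new_tokens):
    if n_tool_calls > 0:
        return "tool_calls"  # tool calls win even if a length limit was also hit
    elif prompt_tokens + completion_tokens >= truncation_length or completion_tokens >= max_new_tokens:
        return "length"
    else:
        return "stop"

assert finish_reason(1, 10, 512, 4096, 512) == "tool_calls"  # previously reported "length"
assert finish_reason(0, 10, 512, 4096, 512) == "length"
assert finish_reason(0, 10, 50, 4096, 512) == "stop"
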
@@ -441,7 +450,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
                 "usage": usage
             }
         else:
-            chunk['usage'] = usage
             yield chunk
     else:
         resp = {

@@ -462,9 +470,13 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
                     "total_tokens": token_count + completion_token_count
                 }
             }
-            if logprob_proc: # not official for chat yet
+            if logprob_proc:
                 top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
                 resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
+                backend_logprobs = get_logprobs_from_backend()
+                if backend_logprobs:
+                    resp[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}

             yield resp

@@ -702,7 +714,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
                 "usage": usage
             }
         else:
-            chunk["usage"] = usage
             yield chunk

@@ -21,6 +21,7 @@ import extensions.openai.completions as OAIcompletions
 import extensions.openai.logits as OAIlogits
 import extensions.openai.models as OAImodels
 from extensions.openai.tokens import token_count, token_decode, token_encode
+from extensions.openai.errors import OpenAIError
 from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger

@@ -94,6 +95,20 @@ app.add_middleware(
 )


+@app.exception_handler(OpenAIError)
+async def openai_error_handler(request: Request, exc: OpenAIError):
+    error_type = "server_error" if exc.code >= 500 else "invalid_request_error"
+    return JSONResponse(
+        status_code=exc.code,
+        content={"error": {
+            "message": exc.message,
+            "type": error_type,
+            "param": getattr(exc, 'param', None),
+            "code": None
+        }}
+    )
+
+
 @app.middleware("http")
 async def validate_host_header(request: Request, call_next):
     # Be strict about only approving access to localhost by default

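With the handler above in place, an OpenAIError surfaces to clients as an OpenAI-style error envelope instead of a bare HTTP 500. A hedged sketch of reading that envelope from a client follows; the base URL is an assumption, and which malformed requests actually reach this handler (as opposed to FastAPI's own request validation) is not shown in this hunk.

# Hedged sketch: read the {"error": {...}} envelope produced by the handler above.
import requests

r = requests.post("http://127.0.0.1:5000/v1/chat/completions", json={}, timeout=30)
if r.status_code >= 400:
    err = r.json().get("error", {})
    # Expected keys: "message", "type" ("invalid_request_error" or "server_error"),
    # "param", and "code" (null)
    print(r.status_code, err.get("type"), err.get("message"))
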
@@ -136,6 +151,8 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
                         break

                     yield {"data": json.dumps(resp)}
+
+                yield {"data": "[DONE]"}
             finally:
                 stop_event.set()
                 response.close()

@@ -176,6 +193,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
                         break

                     yield {"data": json.dumps(resp)}
+
+                yield {"data": "[DONE]"}
             finally:
                 stop_event.set()
                 response.close()

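Both streaming endpoints now terminate their event streams with data: [DONE], matching the OpenAI convention. A hedged client-side sketch that reads the stream and stops on that sentinel; the base URL is an assumption.

# Hedged sketch: consume the SSE stream and stop at the "[DONE]" sentinel added above.
import json
import requests

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Count to five."}], "stream": True},
    stream=True,
    timeout=120,
) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)
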
@@ -152,7 +152,10 @@ class ChatCompletionRequestParams(BaseModel):
     tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
     tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
     logit_bias: dict | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
     max_tokens: int | None = None
+    max_completion_tokens: int | None = None
     n: int | None = Field(default=1, description="Unused parameter.")
     presence_penalty: float | None = shared.args.presence_penalty
     stop: str | List[str] | None = None

@@ -162,6 +165,12 @@ class ChatCompletionRequestParams(BaseModel):
     top_p: float | None = shared.args.top_p
     user: str | None = Field(default=None, description="Unused parameter.")

+    @model_validator(mode='after')
+    def resolve_max_tokens(self):
+        if self.max_tokens is None and self.max_completion_tokens is not None:
+            self.max_tokens = self.max_completion_tokens
+        return self
+
     mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")

     instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/user_data/instruction-templates. If not set, the correct template will be automatically obtained from the model metadata.")

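The same aliasing pattern in isolation, as a hedged standalone sketch: a toy Pydantic v2 model (not the project's class, and assuming Python 3.10+ union syntax to match the project's typing style) showing how a request that only sets max_completion_tokens ends up with max_tokens populated.

# Illustrative only: the max_completion_tokens -> max_tokens aliasing in isolation.
from pydantic import BaseModel, model_validator

class Params(BaseModel):
    max_tokens: int | None = None
    max_completion_tokens: int | None = None

    @model_validator(mode='after')
    def resolve_max_tokens(self):
        if self.max_tokens is None and self.max_completion_tokens is not None:
            self.max_tokens = self.max_completion_tokens
        return self

assert Params(max_completion_tokens=128).max_tokens == 128                 # alias resolved
assert Params(max_tokens=64, max_completion_tokens=128).max_tokens == 64   # explicit max_tokens wins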