mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-17 19:04:39 +01:00
API: Improve OpenAI spec compliance in streaming and non-streaming responses
This commit is contained in:
parent
3304b57bdf
commit
f1cfeae372
|
|
@ -310,28 +310,41 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
|
||||
def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False):
    """Build one streaming chunk dict for a chat completion response.

    The delta carries only the fields that apply to this chunk, instead of
    always sending role/content/tool_calls.

    Args:
        content: Text for this chunk. Added to the delta whenever it is not
            None, so an empty string is still sent.
        chunk_tool_calls: Tool calls for this chunk. Added to the delta only
            when truthy.
        include_role: When true, the delta also announces the assistant role.

    Returns:
        A dict with id, object, created, model, system_fingerprint and a
        one-element list stored under the key held in ``resp_list``. That
        element has ``finish_reason`` set to None; callers overwrite it on the
        final chunk.
    """
    delta = {}
    if include_role:
        delta['role'] = 'assistant'
        # NOTE(review): the flattened diff does not show whether 'refusal' is
        # nested under include_role or set on every chunk; it is grouped with
        # 'role' here because the two lines are adjacent — confirm upstream.
        delta['refusal'] = None
    if content is not None:
        delta['content'] = content
    if chunk_tool_calls:
        delta['tool_calls'] = chunk_tool_calls

    chunk = {
        "id": cmpl_id,
        "object": object_type,
        "created": created_time,
        "model": shared.model_name,
        "system_fingerprint": None,
        resp_list: [{
            "index": 0,
            "finish_reason": None,
            "delta": delta,
            # Always present; replaced below when a logprob processor is active.
            "logprobs": None,
        }],
    }

    if logprob_proc:  # not official for chat yet
        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
        chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

    return chunk
|
||||
|
||||
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||
stream_options = body.get('stream_options')
|
||||
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||
|
||||
# generate reply #######################################
|
||||
if prompt_only:
|
||||
prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
|
||||
|
|
@ -339,7 +352,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
return
|
||||
|
||||
if stream:
|
||||
yield chat_streaming_chunk('')
|
||||
chunk = chat_streaming_chunk('', include_role=True)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||
|
|
@ -372,6 +388,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
continue
|
||||
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
|
||||
seen_content = answer
|
||||
yield chunk
|
||||
|
|
@ -389,25 +407,42 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
stop_reason = "length"
|
||||
|
||||
if stream:
|
||||
chunk = chat_streaming_chunk('', tool_calls)
|
||||
chunk = chat_streaming_chunk(chunk_tool_calls=tool_calls)
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
usage = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||
yield {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [],
|
||||
"usage": usage
|
||||
}
|
||||
else:
|
||||
chunk['usage'] = usage
|
||||
yield chunk
|
||||
else:
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
|
||||
"message": {"role": "assistant", "refusal": None, "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
|
||||
"logprobs": None,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
|
|
@ -418,8 +453,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
yield resp
|
||||
|
||||
|
|
@ -427,7 +460,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_event=None):
|
||||
object_type = 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
cmpl_id = "cmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
|
|
@ -548,6 +581,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
|
|
@ -572,6 +606,10 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||
stream_options = body.get('stream_options')
|
||||
include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
if logprob_proc:
|
||||
|
|
@ -587,6 +625,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
|
|
@ -597,7 +636,10 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk(prefix)
|
||||
chunk = text_streaming_chunk(prefix)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
|
|
@ -617,6 +659,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
seen_content = answer
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
|
|
@ -626,13 +670,28 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
chunk = text_streaming_chunk(suffix)
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
usage = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
# Separate usage-only chunk with choices: [] per OpenAI spec
|
||||
yield {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
"system_fingerprint": None,
|
||||
resp_list: [],
|
||||
"usage": usage
|
||||
}
|
||||
else:
|
||||
chunk["usage"] = usage
|
||||
yield chunk
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool = False, stop_event=None) -> dict:
|
||||
|
|
|
|||
|
|
@ -99,6 +99,10 @@ class ToolCall(BaseModel):
|
|||
function: FunctionCall
|
||||
|
||||
|
||||
class StreamOptions(BaseModel):
    """Request model for the ``stream_options`` field of completion requests."""

    # Read by the streaming handlers to decide whether usage is reported in
    # streamed chunks; falsy (False or None) leaves usage reporting off.
    include_usage: bool | None = False
|
||||
|
||||
|
||||
class CompletionRequestParams(BaseModel):
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||
prompt: str | List[str] | None = Field(default=None, description="Text prompt for completion. Can also use 'messages' format for multimodal.")
|
||||
|
|
@ -113,6 +117,7 @@ class CompletionRequestParams(BaseModel):
|
|||
presence_penalty: float | None = shared.args.presence_penalty
|
||||
stop: str | List[str] | None = None
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
suffix: str | None = None
|
||||
temperature: float | None = shared.args.temperature
|
||||
top_p: float | None = shared.args.top_p
|
||||
|
|
@ -151,6 +156,7 @@ class ChatCompletionRequestParams(BaseModel):
|
|||
presence_penalty: float | None = shared.args.presence_penalty
|
||||
stop: str | List[str] | None = None
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
temperature: float | None = shared.args.temperature
|
||||
top_p: float | None = shared.args.top_p
|
||||
user: str | None = Field(default=None, description="Unused parameter.")
|
||||
|
|
|
|||
Loading…
Reference in a new issue