mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
API: Add tool_choice support and fix tool_calls spec compliance
This commit is contained in:
parent
b5cac2e3b2
commit
2549f7c33b
|
|
@ -218,6 +218,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
|
||||
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
|
||||
|
||||
tool_choice = body.get('tool_choice', None)
|
||||
if tool_choice == "none":
|
||||
tools = None # Disable tool detection entirely
|
||||
|
||||
messages = body['messages']
|
||||
for m in messages:
|
||||
if 'role' not in m:
|
||||
|
|
@ -367,6 +371,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
end_last_tool_call = 0
|
||||
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
|
||||
|
||||
# Filter supported_tools when tool_choice specifies a particular function
|
||||
if supported_tools and isinstance(tool_choice, dict):
|
||||
specified_func = tool_choice.get("function", {}).get("name")
|
||||
if specified_func and specified_func in supported_tools:
|
||||
supported_tools = [specified_func]
|
||||
|
||||
for a in generator:
|
||||
answer = a['internal'][-1][1]
|
||||
|
||||
|
|
@ -375,11 +385,17 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
if len(tool_call) > 0:
|
||||
for tc in tool_call:
|
||||
tc["id"] = getToolCallId()
|
||||
tc["index"] = len(tool_calls)
|
||||
if stream:
|
||||
tc["index"] = len(tool_calls)
|
||||
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
|
||||
tool_calls.append(tc)
|
||||
end_last_tool_call = len(answer)
|
||||
|
||||
# Stop generation before streaming content if tool_calls were detected,
|
||||
# so that raw tool markup is not sent as content deltas.
|
||||
if len(tool_calls) > 0:
|
||||
break
|
||||
|
||||
if stream:
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
|
@ -394,10 +410,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
seen_content = answer
|
||||
yield chunk
|
||||
|
||||
# stop generation if tool_calls were generated previously
|
||||
if len(tool_calls) > 0:
|
||||
break
|
||||
|
||||
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
|
|
@ -441,7 +453,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "refusal": None, "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
|
||||
"message": {"role": "assistant", "refusal": None, "content": None if tool_calls else answer, **({"tool_calls": tool_calls} if tool_calls else {})},
|
||||
"logprobs": None,
|
||||
}],
|
||||
"usage": {
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ class ChatCompletionRequestParams(BaseModel):
|
|||
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
||||
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
||||
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
|
||||
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
|
||||
logit_bias: dict | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = Field(default=1, description="Unused parameter.")
|
||||
|
|
|
|||
Loading…
Reference in a new issue