API: Add tool_choice support and fix tool_calls spec compliance

This commit is contained in:
oobabooga 2026-03-12 10:28:49 -03:00
parent b5cac2e3b2
commit 2549f7c33b
2 changed files with 19 additions and 6 deletions

View file

@ -218,6 +218,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
tool_choice = body.get('tool_choice', None)
if tool_choice == "none":
tools = None # Disable tool detection entirely
messages = body['messages']
for m in messages:
if 'role' not in m:
@ -367,6 +371,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
end_last_tool_call = 0
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
# Filter supported_tools when tool_choice specifies a particular function
if supported_tools and isinstance(tool_choice, dict):
specified_func = tool_choice.get("function", {}).get("name")
if specified_func and specified_func in supported_tools:
supported_tools = [specified_func]
for a in generator:
answer = a['internal'][-1][1]
@ -375,11 +385,17 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if len(tool_call) > 0:
for tc in tool_call:
tc["id"] = getToolCallId()
tc["index"] = len(tool_calls)
if stream:
tc["index"] = len(tool_calls)
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
tool_calls.append(tc)
end_last_tool_call = len(answer)
# Stop generation before streaming content if tool_calls were detected,
# so that raw tool markup is not sent as content deltas.
if len(tool_calls) > 0:
break
if stream:
len_seen = len(seen_content)
new_content = answer[len_seen:]
@ -394,10 +410,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
seen_content = answer
yield chunk
# stop generation if tool_calls were generated previously
if len(tool_calls) > 0:
break
token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
@ -441,7 +453,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "refusal": None, "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})},
"message": {"role": "assistant", "refusal": None, "content": None if tool_calls else answer, **({"tool_calls": tool_calls} if tool_calls else {})},
"logprobs": None,
}],
"usage": {

View file

@ -150,6 +150,7 @@ class ChatCompletionRequestParams(BaseModel):
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.")
logit_bias: dict | None = None
max_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.")