From 2549f7c33b8989c9604edc6476fa2354f9bb6662 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 12 Mar 2026 10:28:49 -0300 Subject: [PATCH] API: Add tool_choice support and fix tool_calls spec compliance --- extensions/openai/completions.py | 24 ++++++++++++++++++------ extensions/openai/typing.py | 1 + 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 03c4b03e..8d3cce57 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -218,6 +218,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0: tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails + tool_choice = body.get('tool_choice', None) + if tool_choice == "none": + tools = None # Disable tool detection entirely + messages = body['messages'] for m in messages: if 'role' not in m: @@ -367,6 +371,12 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p end_last_tool_call = 0 supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None + # Filter supported_tools when tool_choice specifies a particular function + if supported_tools and isinstance(tool_choice, dict): + specified_func = tool_choice.get("function", {}).get("name") + if specified_func and specified_func in supported_tools: + supported_tools = [specified_func] + for a in generator: answer = a['internal'][-1][1] @@ -375,11 +385,17 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p if len(tool_call) > 0: for tc in tool_call: tc["id"] = getToolCallId() - tc["index"] = len(tool_calls) + if stream: + tc["index"] = len(tool_calls) tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"]) tool_calls.append(tc) end_last_tool_call = len(answer) + # Stop generation before streaming content if tool_calls were detected, + # so that raw tool markup is not sent as content deltas. + if len(tool_calls) > 0: + break + if stream: len_seen = len(seen_content) new_content = answer[len_seen:] @@ -394,10 +410,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p seen_content = answer yield chunk - # stop generation if tool_calls were generated previously - if len(tool_calls) > 0: - break - token_count = shared.model.last_prompt_token_count if hasattr(shared.model, 'last_prompt_token_count') else 0 completion_token_count = len(encode(answer)[0]) stop_reason = "stop" @@ -441,7 +453,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p resp_list: [{ "index": 0, "finish_reason": stop_reason, - "message": {"role": "assistant", "refusal": None, "content": answer, **({"tool_calls": tool_calls} if tool_calls else {})}, + "message": {"role": "assistant", "refusal": None, "content": None if tool_calls else answer, **({"tool_calls": tool_calls} if tool_calls else {})}, "logprobs": None, }], "usage": { diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 078bd201..4d0f4a4a 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -150,6 +150,7 @@ class ChatCompletionRequestParams(BaseModel): function_call: str | dict | None = Field(default=None, description="Unused parameter.") functions: List[dict] | None = Field(default=None, description="Unused parameter.") tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.") + tool_choice: str | dict | None = Field(default=None, description="Controls tool use: 'auto', 'none', 'required', or {\"type\": \"function\", \"function\": {\"name\": \"...\"}}.") logit_bias: dict | None = None max_tokens: int | None = None n: int | None = Field(default=1, description="Unused parameter.")