diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index a7d8b4e4..ed0bcc40 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -1,11 +1,14 @@
 import copy
 import time
+import json
 from collections import deque
 
 import tiktoken
 from extensions.openai.errors import InvalidRequestError
-from extensions.openai.utils import debug_msg
+from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
+from extensions.openai.typing import ToolDefinition
+from pydantic import ValidationError
 from modules import shared
 from modules.chat import (
     generate_chat_prompt,
@@ -99,19 +102,24 @@ def convert_history(history):
             user_input = content
             user_input_last = True
             if current_message:
-                chat_dialogue.append([current_message, ''])
+                chat_dialogue.append([current_message, '', ''])
                 current_message = ""
 
             current_message = content
         elif role == "assistant":
+            if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
+                continue  # skip tool call entries; they are reconstructed from the tool role below
             current_reply = content
             user_input_last = False
             if current_message:
-                chat_dialogue.append([current_message, current_reply])
+                chat_dialogue.append([current_message, current_reply, ''])
                 current_message = ""
                 current_reply = ""
             else:
-                chat_dialogue.append(['', current_reply])
+                chat_dialogue.append(['', current_reply, ''])
+        elif role == "tool":
+            user_input_last = False
+            chat_dialogue.append(['', '', content])
         elif role == "system":
             system_message += f"\n{content}" if system_message else content
@@ -131,6 +139,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     if 'messages' not in body:
         raise InvalidRequestError(message="messages is required", param='messages')
 
+    tools = None
+    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
+        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails
+
     messages = body['messages']
     for m in messages:
         if 'role' not in m:
@@ -188,6 +200,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         'custom_system_message': custom_system_message,
         'chat_template_str': chat_template_str,
         'chat-instruct_command': chat_instruct_command,
+        'tools': tools,
         'history': history,
         'stream': stream
     })
@@ -200,7 +213,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
 
-    def chat_streaming_chunk(content):
+    def chat_streaming_chunk(content, chunk_tool_calls=None):
         # begin streaming
         chunk = {
             "id": cmpl_id,
@@ -210,7 +223,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
             }],
         }
 
@@ -219,6 +232,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
         # else:
         #     chunk[resp_list][0]["logprobs"] = None
+
         return chunk
 
     # generate reply #######################################
@@ -227,8 +241,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         yield {'prompt': prompt}
         return
 
-    debug_msg({'prompt': prompt, 'generate_params': generate_params})
-
     if stream:
         yield chat_streaming_chunk('')
 
@@ -238,8 +250,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     answer = ''
     seen_content = ''
 
+    tool_calls = []
+    end_last_tool_call = 0
+    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
+
     for a in generator:
         answer = a['internal'][-1][1]
+
+        if supported_tools is not None:
+            tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
+            if len(tool_call) > 0:
+                for tc in tool_call:
+                    tc["id"] = getToolCallId()
+                    tc["index"] = len(tool_calls)  # integer index, matching the ToolCall model
+                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                    tool_calls.append(tc)
+                end_last_tool_call = len(answer)
+
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
@@ -247,18 +274,25 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                 continue
 
-            seen_content = answer
             chunk = chat_streaming_chunk(new_content)
+
+            seen_content = answer
             yield chunk
 
+        # stop generation if tool_calls were generated previously
+        if len(tool_calls) > 0:
+            break
+
     token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
+    if len(tool_calls) > 0:
+        stop_reason = "tool_calls"
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
         stop_reason = "length"
 
     if stream:
-        chunk = chat_streaming_chunk('')
+        chunk = chat_streaming_chunk('', tool_calls)
         chunk[resp_list][0]['finish_reason'] = stop_reason
         chunk['usage'] = {
             "prompt_tokens": token_count,
@@ -276,7 +310,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         resp_list: [{
             "index": 0,
             "finish_reason": stop_reason,
-            "message": {"role": "assistant", "content": answer}
+            "message": {"role": "assistant", "content": answer},
+            "tool_calls": tool_calls
         }],
         "usage": {
             "prompt_tokens": token_count,
@@ -465,3 +500,19 @@ def completions(body: dict, is_legacy: bool = False) -> dict:
 def stream_completions(body: dict, is_legacy: bool = False):
     for resp in completions_common(body, is_legacy, stream=True):
         yield resp
+
+
+def validateTools(tools: list[dict]):
+    # Validate each tool definition in the JSON array
+    valid_tools = None
+    for idx in range(len(tools)):
+        tool = tools[idx]
+        try:
+            tool_definition = ToolDefinition(**tool)  # instantiation performs the validation; result unused
+            if valid_tools is None:
+                valid_tools = []
+            valid_tools.append(tool)
+        except ValidationError:
+            raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
+
+    return valid_tools
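For reference, a minimal client-side sketch of the new `tools` round trip. The endpoint URL and the `get_current_weather` definition are illustrative assumptions, not part of this patch; what the patch itself guarantees is the `tools` request field, the `tool_calls` finish reason, and the parsed calls sitting next to `message` at the choice level.

```python
# Sketch of a client exercising the new `tools` parameter; assumes the
# OpenAI-compatible API of text-generation-webui is listening locally
# (the URL and tool definition below are invented for illustration).
import json

import requests

url = "http://127.0.0.1:5000/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "City name"}},
                "required": ["city"],
            },
        },
    }],
}

choice = requests.post(url, json=payload, timeout=120).json()["choices"][0]
if choice["finish_reason"] == "tool_calls":
    for call in choice["tool_calls"]:
        # arguments were serialized with json.dumps() in the generator loop above
        print(call["function"]["name"], json.loads(call["function"]["arguments"]))
```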
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index b1979cbc..b28ebb4e 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -1,8 +1,8 @@
 import json
 import time
-from typing import Dict, List
+from typing import Dict, List, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 
 
 class GenerationOptions(BaseModel):
@@ -54,6 +54,48 @@ class GenerationOptions(BaseModel):
     grammar_string: str = ""
 
 
+class ToolDefinition(BaseModel):
+    function: 'ToolFunction'
+    type: str
+
+
+class ToolFunction(BaseModel):
+    description: str
+    name: str
+    parameters: 'ToolParameters'
+
+
+class ToolParameters(BaseModel):
+    properties: Optional[Dict[str, 'ToolProperty']] = None
+    required: Optional[list[str]] = None
+    type: str
+    description: Optional[str] = None
+
+
+class ToolProperty(BaseModel):
+    description: Optional[str] = None
+    type: Optional[str] = None  # optional because some schemas use 'anyOf' instead of a plain type, e.g. {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: Optional[str] = None
+    parameters: Optional[str] = None
+
+    @validator('parameters', always=True, allow_reuse=True)
+    def checkPropertyArgsOrParams(cls, v, values, **kwargs):
+        if not v and not values.get('arguments'):
+            raise ValueError("At least one of 'arguments' or 'parameters' must be provided as property in FunctionCall type")
+        return v
+
+
+class ToolCall(BaseModel):
+    id: str
+    index: int
+    type: str
+    function: FunctionCall
+
+
 class CompletionRequestParams(BaseModel):
     model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
     prompt: str | List[str]
@@ -92,6 +134,7 @@ class ChatCompletionRequestParams(BaseModel):
     frequency_penalty: float | None = 0
     function_call: str | dict | None = Field(default=None, description="Unused parameter.")
     functions: List[dict] | None = Field(default=None, description="Unused parameter.")
+    tools: List[dict] | None = Field(default=None, description="Tool signatures passed via MCP.")
     logit_bias: dict | None = None
     max_tokens: int | None = None
     n: int | None = Field(default=1, description="Unused parameter.")
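A quick sketch of how these models gate malformed tool definitions; the `git_create_branch` payload reuses the example from the `ToolProperty` comment, and the exact error text depends on the installed pydantic version.

```python
# Valid and invalid tool definitions against the new pydantic models.
from pydantic import ValidationError

from extensions.openai.typing import ToolDefinition

valid = {
    "type": "function",
    "function": {
        "name": "git_create_branch",
        "description": "Creates a new branch from an optional base branch",
        "parameters": {
            "type": "object",
            "properties": {
                "repo_path": {"title": "Repo Path", "type": "string"},
                "branch_name": {"title": "Branch Name", "type": "string"},
                # anyOf instead of a plain type: allowed because ToolProperty.type is optional
                "base_branch": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": None, "title": "Base Branch"},
            },
            "required": ["repo_path", "branch_name"],
        },
    },
}
ToolDefinition(**valid)  # passes

try:
    ToolDefinition(**{"type": "function", "function": {"name": "broken"}})  # missing description/parameters
except ValidationError:
    print("rejected")  # validateTools() maps this to an InvalidRequestError for the API caller
```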
candidate_dict["function"]: + candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] + del candidate_dict["function"]["parameters"] + return candidate_dict + return None + + +def parseToolCall(answer: str, tool_names: list[str]): + matches = [] + + # abort on very short answers to save computation cycles + if len(answer) < 10: + return matches + + # Define the regex pattern to find the JSON content wrapped in , , , and other tags observed from various models + patterns = [ r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)" ] + + for pattern in patterns: + for match in re.finditer(pattern, answer, re.DOTALL): + # print(match.group(2)) + if match.group(2) is None: + continue + # remove backtick wraps if present + candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip()) + candidate = re.sub(r"```$", "", candidate.strip()) + # unwrap inner tags + candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL) + # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually + if re.search(r"\}\s*\n\s*\{", candidate) is not None: + candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) + if not candidate.strip().startswith("["): + candidate = "[" + candidate + "]" + + candidates = [] + try: + # parse the candidate JSON into a dictionary + candidates = json.loads(candidate) + if not isinstance(candidates, list): + candidates = [candidates] + except json.JSONDecodeError: + # Ignore invalid JSON silently + continue + + for candidate_dict in candidates: + checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) + if checked_candidate is not None: + matches.append(checked_candidate) + + # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags + if len(matches) == 0: + try: + candidate = answer + # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually + if re.search(r"\}\s*\n\s*\{", candidate) is not None: + candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) + if not candidate.strip().startswith("["): + candidate = "[" + candidate + "]" + # parse the candidate JSON into a dictionary + candidates = json.loads(candidate) + if not isinstance(candidates, list): + candidates = [candidates] + for candidate_dict in candidates: + checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names) + if checked_candidate is not None: + matches.append(checked_candidate) + except json.JSONDecodeError: + # Ignore invalid JSON silently + pass + + return matches diff --git a/modules/chat.py b/modules/chat.py index feac6bdd..b524b1b9 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -145,7 +145,7 @@ def generate_chat_prompt(user_input, state, **kwargs): instruct_renderer = partial( instruction_template.render, builtin_tools=None, - tools=None, + tools=state['tools'] if 'tools' in state else None, tools_in_user_message=False, add_generation_prompt=False ) @@ -171,9 +171,13 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "system", "content": context}) insert_pos = len(messages) - for user_msg, assistant_msg in reversed(history): - user_msg = user_msg.strip() - assistant_msg = assistant_msg.strip() + for entry in reversed(history): + user_msg = entry[0].strip() + assistant_msg = entry[1].strip() + tool_msg = entry[2].strip() if len(entry) > 2 else '' + + if tool_msg: + 
diff --git a/modules/chat.py b/modules/chat.py
index feac6bdd..b524b1b9 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -145,7 +145,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
         instruct_renderer = partial(
             instruction_template.render,
             builtin_tools=None,
-            tools=None,
+            tools=state['tools'] if 'tools' in state else None,
             tools_in_user_message=False,
             add_generation_prompt=False
         )
@@ -171,9 +171,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "system", "content": context})
 
     insert_pos = len(messages)
-    for user_msg, assistant_msg in reversed(history):
-        user_msg = user_msg.strip()
-        assistant_msg = assistant_msg.strip()
+    for entry in reversed(history):
+        user_msg = entry[0].strip()
+        assistant_msg = entry[1].strip()
+        tool_msg = entry[2].strip() if len(entry) > 2 else ''
+
+        if tool_msg:
+            messages.insert(insert_pos, {"role": "tool", "content": tool_msg})
 
         if assistant_msg:
             messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
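Finally, a sketch of how the widened history rows render into chat-template messages; the rows and values below are invented for illustration. `generate_chat_prompt()` emits one message per non-empty slot, in user, assistant, tool order within each row.

```python
# History rows grow from [user, assistant] pairs to [user, assistant, tool] triples.
history = [
    ["What is the weather in Berlin?", "", ""],
    ["", "", '{"temperature": "7C", "sky": "overcast"}'],  # tool result row
    ["", "Currently 7C and overcast in Berlin.", ""],
]

for user_msg, assistant_msg, tool_msg in history:
    if user_msg.strip():
        print({"role": "user", "content": user_msg})
    if assistant_msg.strip():
        print({"role": "assistant", "content": assistant_msg})
    if tool_msg.strip():
        print({"role": "tool", "content": tool_msg})
```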