diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 90366270..56d91582 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -2,7 +2,7 @@ import json import time from typing import Dict, List, Optional -from pydantic import BaseModel, Field, field_validator, validator +from pydantic import BaseModel, Field, model_validator, validator class GenerationOptions(BaseModel): @@ -116,16 +116,11 @@ class CompletionRequestParams(BaseModel): top_p: float | None = 1 user: str | None = Field(default=None, description="Unused parameter.") - @field_validator('prompt', 'messages') - @classmethod - def validate_prompt_or_messages(cls, v, info): - """Ensure either 'prompt' or 'messages' is provided for completions.""" - if info.field_name == 'prompt': # If we're validating 'prompt', check if neither prompt nor messages will be set - messages = info.data.get('messages') - if v is None and messages is None: - raise ValueError("Either 'prompt' or 'messages' must be provided") - - return v + @model_validator(mode='after') + def validate_prompt_or_messages(self): + if self.prompt is None and self.messages is None: + raise ValueError("Either 'prompt' or 'messages' must be provided") + return self class CompletionRequest(GenerationOptions, CompletionRequestParams):