API: Improve a validation

This commit is contained in:
oobabooga 2025-08-11 12:39:18 -07:00
parent a78ca6ffcd
commit 765af1ba17

View file

@ -2,7 +2,7 @@ import json
import time
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator, validator
from pydantic import BaseModel, Field, model_validator, validator
class GenerationOptions(BaseModel):
@ -116,16 +116,11 @@ class CompletionRequestParams(BaseModel):
top_p: float | None = 1
user: str | None = Field(default=None, description="Unused parameter.")
@field_validator('prompt', 'messages')
@classmethod
def validate_prompt_or_messages(cls, v, info):
"""Ensure either 'prompt' or 'messages' is provided for completions."""
if info.field_name == 'prompt': # If we're validating 'prompt', check if neither prompt nor messages will be set
messages = info.data.get('messages')
if v is None and messages is None:
raise ValueError("Either 'prompt' or 'messages' must be provided")
return v
@model_validator(mode='after')
def validate_prompt_or_messages(self):
if self.prompt is None and self.messages is None:
raise ValueError("Either 'prompt' or 'messages' must be provided")
return self
class CompletionRequest(GenerationOptions, CompletionRequestParams):