API: Improve a validation

This commit is contained in:
oobabooga 2025-08-11 12:39:18 -07:00
parent a78ca6ffcd
commit 765af1ba17

View file

@ -2,7 +2,7 @@ import json
import time import time
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator, validator from pydantic import BaseModel, Field, model_validator, validator
class GenerationOptions(BaseModel): class GenerationOptions(BaseModel):
@ -116,16 +116,11 @@ class CompletionRequestParams(BaseModel):
top_p: float | None = 1 top_p: float | None = 1
user: str | None = Field(default=None, description="Unused parameter.") user: str | None = Field(default=None, description="Unused parameter.")
@field_validator('prompt', 'messages') @model_validator(mode='after')
@classmethod def validate_prompt_or_messages(self):
def validate_prompt_or_messages(cls, v, info): if self.prompt is None and self.messages is None:
"""Ensure either 'prompt' or 'messages' is provided for completions.""" raise ValueError("Either 'prompt' or 'messages' must be provided")
if info.field_name == 'prompt': # If we're validating 'prompt', check if neither prompt nor messages will be set return self
messages = info.data.get('messages')
if v is None and messages is None:
raise ValueError("Either 'prompt' or 'messages' must be provided")
return v
class CompletionRequest(GenerationOptions, CompletionRequestParams): class CompletionRequest(GenerationOptions, CompletionRequestParams):