mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
API: Improve a validation
This commit is contained in:
parent
a78ca6ffcd
commit
765af1ba17
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, validator
|
from pydantic import BaseModel, Field, model_validator, validator
|
||||||
|
|
||||||
|
|
||||||
class GenerationOptions(BaseModel):
|
class GenerationOptions(BaseModel):
|
||||||
|
|
@ -116,16 +116,11 @@ class CompletionRequestParams(BaseModel):
|
||||||
top_p: float | None = 1
|
top_p: float | None = 1
|
||||||
user: str | None = Field(default=None, description="Unused parameter.")
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
@field_validator('prompt', 'messages')
|
@model_validator(mode='after')
|
||||||
@classmethod
|
def validate_prompt_or_messages(self):
|
||||||
def validate_prompt_or_messages(cls, v, info):
|
if self.prompt is None and self.messages is None:
|
||||||
"""Ensure either 'prompt' or 'messages' is provided for completions."""
|
raise ValueError("Either 'prompt' or 'messages' must be provided")
|
||||||
if info.field_name == 'prompt': # If we're validating 'prompt', check if neither prompt nor messages will be set
|
return self
|
||||||
messages = info.data.get('messages')
|
|
||||||
if v is None and messages is None:
|
|
||||||
raise ValueError("Either 'prompt' or 'messages' must be provided")
|
|
||||||
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(GenerationOptions, CompletionRequestParams):
|
class CompletionRequest(GenerationOptions, CompletionRequestParams):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue