API: Add command-line flags to override default generation parameters

commit 27bcc45c18 (parent 8a9afcbec6)
Author: oobabooga
Date: 2026-03-06 01:36:30 -03:00
3 changed files with 138 additions and 89 deletions
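In effect, sampling defaults that were previously hardcoded in the API schemas and in default_preset() can now be set once at launch. A hypothetical invocation using flags added by this commit (the values are illustrative, not recommendations):

python server.py --api --temperature 0.7 --top-p 0.9 --repetition-penalty 1.1

Any API request that omits these parameters then inherits the launch-time values.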

extensions/openai/typing.py

@@ -4,56 +4,58 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, Field, model_validator, validator
from modules import shared
class GenerationOptions(BaseModel):
preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/user_data/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
dynatemp_low: float = 1
dynatemp_high: float = 1
dynatemp_exponent: float = 1
smoothing_factor: float = 0
smoothing_curve: float = 1
min_p: float = 0
top_k: int = 0
typical_p: float = 1
xtc_threshold: float = 0.1
xtc_probability: float = 0
epsilon_cutoff: float = 0
eta_cutoff: float = 0
tfs: float = 1
top_a: float = 0
top_n_sigma: float = 0
adaptive_target: float = 0
adaptive_decay: float = 0.9
dry_multiplier: float = 0
dry_allowed_length: int = 2
dry_base: float = 1.75
repetition_penalty: float = 1
encoder_repetition_penalty: float = 1
no_repeat_ngram_size: int = 0
repetition_penalty_range: int = 1024
penalty_alpha: float = 0
guidance_scale: float = 1
mirostat_mode: int = 0
mirostat_tau: float = 5
mirostat_eta: float = 0.1
dynatemp_low: float = shared.args.dynatemp_low
dynatemp_high: float = shared.args.dynatemp_high
dynatemp_exponent: float = shared.args.dynatemp_exponent
smoothing_factor: float = shared.args.smoothing_factor
smoothing_curve: float = shared.args.smoothing_curve
min_p: float = shared.args.min_p
top_k: int = shared.args.top_k
typical_p: float = shared.args.typical_p
xtc_threshold: float = shared.args.xtc_threshold
xtc_probability: float = shared.args.xtc_probability
epsilon_cutoff: float = shared.args.epsilon_cutoff
eta_cutoff: float = shared.args.eta_cutoff
tfs: float = shared.args.tfs
top_a: float = shared.args.top_a
top_n_sigma: float = shared.args.top_n_sigma
adaptive_target: float = shared.args.adaptive_target
adaptive_decay: float = shared.args.adaptive_decay
dry_multiplier: float = shared.args.dry_multiplier
dry_allowed_length: int = shared.args.dry_allowed_length
dry_base: float = shared.args.dry_base
repetition_penalty: float = shared.args.repetition_penalty
encoder_repetition_penalty: float = shared.args.encoder_repetition_penalty
no_repeat_ngram_size: int = shared.args.no_repeat_ngram_size
repetition_penalty_range: int = shared.args.repetition_penalty_range
penalty_alpha: float = shared.args.penalty_alpha
guidance_scale: float = shared.args.guidance_scale
mirostat_mode: int = shared.args.mirostat_mode
mirostat_tau: float = shared.args.mirostat_tau
mirostat_eta: float = shared.args.mirostat_eta
prompt_lookup_num_tokens: int = 0
max_tokens_second: int = 0
do_sample: bool = True
dynamic_temperature: bool = False
temperature_last: bool = False
do_sample: bool = shared.args.do_sample
dynamic_temperature: bool = shared.args.dynamic_temperature
temperature_last: bool = shared.args.temperature_last
auto_max_new_tokens: bool = False
ban_eos_token: bool = False
add_bos_token: bool = True
enable_thinking: bool = True
reasoning_effort: str = "medium"
enable_thinking: bool = shared.args.enable_thinking
reasoning_effort: str = shared.args.reasoning_effort
skip_special_tokens: bool = True
static_cache: bool = False
truncation_length: int = 0
seed: int = -1
sampler_priority: List[str] | str | None = Field(default=['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'top_n_sigma', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'], description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
sampler_priority: List[str] | str | None = Field(default=shared.args.sampler_priority, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
custom_token_bans: str = ""
negative_prompt: str = ''
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
dry_sequence_breakers: str = shared.args.dry_sequence_breakers
grammar_string: str = ""
@@ -105,17 +107,17 @@ class CompletionRequestParams(BaseModel):
messages: List[dict] | None = Field(default=None, description="OpenAI messages format for multimodal support. Alternative to 'prompt'.")
best_of: int | None = Field(default=1, description="Unused parameter.")
echo: bool | None = False
frequency_penalty: float | None = 0
frequency_penalty: float | None = shared.args.frequency_penalty
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 512
n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = 0
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
suffix: str | None = None
temperature: float | None = 1
top_p: float | None = 1
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
user: str | None = Field(default=None, description="Unused parameter.")
@model_validator(mode='after')
@@ -141,18 +143,18 @@ class CompletionResponse(BaseModel):
class ChatCompletionRequestParams(BaseModel):
messages: List[dict]
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
frequency_penalty: float | None = 0
frequency_penalty: float | None = shared.args.frequency_penalty
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
logit_bias: dict | None = None
max_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.")
presence_penalty: float | None = 0
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False
temperature: float | None = 1
top_p: float | None = 1
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
user: str | None = Field(default=None, description="Unused parameter.")
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
@@ -228,11 +230,11 @@ class LogitsRequestParams(BaseModel):
prompt: str
use_samplers: bool = False
top_logits: int | None = 50
frequency_penalty: float | None = 0
frequency_penalty: float | None = shared.args.frequency_penalty
max_tokens: int | None = 512
presence_penalty: float | None = 0
temperature: float | None = 1
top_p: float | None = 1
presence_penalty: float | None = shared.args.presence_penalty
temperature: float | None = shared.args.temperature
top_p: float | None = shared.args.top_p
class LogitsRequest(GenerationOptions, LogitsRequestParams):
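To illustrate the schema change above: a completion request that leaves out sampling parameters now falls back to shared.args (i.e. the CLI flags) instead of the old hardcoded Pydantic defaults. A minimal sketch, assuming the API is listening on its default port 5000:

import requests

# No temperature/top_p in the payload: with this commit, the server fills
# them in from shared.args rather than from fixed literals like 1.0.
resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={"prompt": "Hello", "max_tokens": 32},
)
print(resp.json()["choices"][0]["text"])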

modules/presets.py

@@ -9,47 +9,50 @@ from modules.loaders import loaders_samplers
from modules.logging_colors import logger
default_preset_values = {
'temperature': 1,
'dynatemp_low': 1,
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'smoothing_curve': 1,
'min_p': 0,
'top_p': 1,
'top_k': 0,
'typical_p': 1,
'xtc_threshold': 0.1,
'xtc_probability': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'top_n_sigma': 0,
'adaptive_target': 0,
'adaptive_decay': 0.9,
'dry_multiplier': 0,
'dry_allowed_length': 2,
'dry_base': 1.75,
'repetition_penalty': 1,
'frequency_penalty': 0,
'presence_penalty': 0,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'repetition_penalty_range': 1024,
'penalty_alpha': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'dynamic_temperature': False,
'temperature_last': False,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
}
def default_preset():
result = {
'temperature': 1,
'dynatemp_low': 1,
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'smoothing_curve': 1,
'min_p': 0,
'top_p': 1,
'top_k': 0,
'typical_p': 1,
'xtc_threshold': 0.1,
'xtc_probability': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'top_n_sigma': 0,
'adaptive_target': 0,
'adaptive_decay': 0.9,
'dry_multiplier': 0,
'dry_allowed_length': 2,
'dry_base': 1.75,
'repetition_penalty': 1,
'frequency_penalty': 0,
'presence_penalty': 0,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'repetition_penalty_range': 1024,
'penalty_alpha': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'dynamic_temperature': False,
'temperature_last': False,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
}
result = dict(default_preset_values)
if shared.args.portable:
samplers = result['sampler_priority'].split('\n')
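Hoisting the dict to module level lets shared.py reuse the same values as argparse defaults, while default_preset() hands out a fresh shallow copy per call via dict(default_preset_values). A small illustration of the intended invariant (not part of the commit; assumes a non-portable run, since the portable branch rewrites sampler_priority):

from modules.presets import default_preset, default_preset_values

preset = default_preset()
assert preset == default_preset_values      # same values as the shared dict
assert preset is not default_preset_values  # but a copy, safe to mutate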

modules/shared.py

@@ -10,7 +10,7 @@ import yaml
from modules.logging_colors import logger
from modules.paths import resolve_user_data_dir
from modules.presets import default_preset
from modules.presets import default_preset, default_preset_values
# Resolve user_data directory early (before argparse defaults are set)
user_data_dir = resolve_user_data_dir()
@@ -171,6 +171,50 @@ group.add_argument('--api-enable-ipv6', action='store_true', help='Enable IPv6 for the API')
group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API')
group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
# API generation defaults
_d = default_preset_values
group = parser.add_argument_group('API generation defaults')
group.add_argument('--temperature', type=float, default=_d['temperature'])
group.add_argument('--dynatemp-low', type=float, default=_d['dynatemp_low'])
group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'])
group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'])
group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'])
group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'])
group.add_argument('--min-p', type=float, default=_d['min_p'])
group.add_argument('--top-p', type=float, default=_d['top_p'])
group.add_argument('--top-k', type=int, default=_d['top_k'])
group.add_argument('--typical-p', type=float, default=_d['typical_p'])
group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'])
group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'])
group.add_argument('--epsilon-cutoff', type=float, default=_d['epsilon_cutoff'])
group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'])
group.add_argument('--tfs', type=float, default=_d['tfs'])
group.add_argument('--top-a', type=float, default=_d['top_a'])
group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'])
group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'])
group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'])
group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'])
group.add_argument('--dry-allowed-length', type=int, default=_d['dry_allowed_length'])
group.add_argument('--dry-base', type=float, default=_d['dry_base'])
group.add_argument('--repetition-penalty', type=float, default=_d['repetition_penalty'])
group.add_argument('--frequency-penalty', type=float, default=_d['frequency_penalty'])
group.add_argument('--presence-penalty', type=float, default=_d['presence_penalty'])
group.add_argument('--encoder-repetition-penalty', type=float, default=_d['encoder_repetition_penalty'])
group.add_argument('--no-repeat-ngram-size', type=int, default=_d['no_repeat_ngram_size'])
group.add_argument('--repetition-penalty-range', type=int, default=_d['repetition_penalty_range'])
group.add_argument('--penalty-alpha', type=float, default=_d['penalty_alpha'])
group.add_argument('--guidance-scale', type=float, default=_d['guidance_scale'])
group.add_argument('--mirostat-mode', type=int, default=_d['mirostat_mode'])
group.add_argument('--mirostat-tau', type=float, default=_d['mirostat_tau'])
group.add_argument('--mirostat-eta', type=float, default=_d['mirostat_eta'])
group.add_argument('--do-sample', action=argparse.BooleanOptionalAction, default=_d['do_sample'])
group.add_argument('--dynamic-temperature', action=argparse.BooleanOptionalAction, default=_d['dynamic_temperature'])
group.add_argument('--temperature-last', action=argparse.BooleanOptionalAction, default=_d['temperature_last'])
group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority'])
group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'])
group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True)
group.add_argument('--reasoning-effort', type=str, default='medium')
# Handle CMD_FLAGS.txt
cmd_flags_path = user_data_dir / "CMD_FLAGS.txt"
if cmd_flags_path.exists():
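A note on the boolean flags: argparse.BooleanOptionalAction (standard library, Python 3.9+) auto-generates a matching --no-* negation for each flag, so --do-sample/--no-do-sample, --enable-thinking/--no-enable-thinking, and so on all come for free. Its behavior in isolation:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--do-sample', action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).do_sample)                  # True  (default kept)
print(parser.parse_args(['--no-do-sample']).do_sample)  # False (negated form)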