text-generation-webui/modules/presets.py

130 lines
3.9 KiB
Python
Raw Permalink Normal View History

import functools
2024-02-06 17:51:34 +01:00
import pprint
from pathlib import Path
import yaml
from modules import shared
from modules.loaders import loaders_samplers
from modules.logging_colors import logger
def default_preset():
result = {
'temperature': 1,
'dynatemp_low': 1,
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'smoothing_curve': 1,
2023-11-03 16:25:22 +01:00
'min_p': 0,
2025-01-10 22:04:32 +01:00
'top_p': 1,
'top_k': 0,
'typical_p': 1,
2025-01-10 22:04:32 +01:00
'xtc_threshold': 0.1,
'xtc_probability': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
2025-01-10 22:04:32 +01:00
'tfs': 1,
'top_a': 0,
2025-03-14 20:45:11 +01:00
'top_n_sigma': 0,
2025-01-10 22:04:32 +01:00
'dry_multiplier': 0,
'dry_allowed_length': 2,
'dry_base': 1.75,
'repetition_penalty': 1,
'frequency_penalty': 0,
'presence_penalty': 0,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'repetition_penalty_range': 1024,
'penalty_alpha': 0,
2025-01-10 22:04:32 +01:00
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
2025-01-10 22:04:32 +01:00
'dynamic_temperature': False,
'temperature_last': False,
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
}
if shared.args.portable:
samplers = result['sampler_priority'].split('\n')
samplers = [sampler for sampler in samplers if sampler in ["dry", "top_k", "top_p", "top_n_sigma", "min_p", "temperature", "xtc", "typical_p", "repetition_penalty"]]
result['sampler_priority'] = '\n'.join(samplers)
return result
def presets_params():
return [k for k in default_preset()]
2024-03-28 20:45:03 +01:00
def load_preset(name, verbose=False):
generate_params = default_preset()
if name not in ['None', None, '']:
2025-04-26 13:56:54 +02:00
path = Path(f'user_data/presets/{name}.yaml')
if path.exists():
with open(path, 'r') as infile:
preset = yaml.safe_load(infile)
for k in preset:
generate_params[k] = preset[k]
else:
logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.")
2024-03-28 20:45:03 +01:00
if verbose:
logger.info(f"\"{name}\" preset:")
pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(generate_params))
return generate_params
@functools.cache
def load_preset_memoized(name):
return load_preset(name)
def load_preset_for_ui(name, state):
2024-03-28 20:45:03 +01:00
generate_params = load_preset(name, verbose=True)
state.update(generate_params)
return state, *[generate_params[k] for k in presets_params()]
def reset_preset_for_ui(name, state):
"""Reset current preset to its saved values from file"""
generate_params = load_preset(name, verbose=True)
state.update(generate_params)
return state, *[generate_params[k] for k in presets_params()]
def neutralize_samplers_for_ui(state):
"""Set all samplers to their default/neutral values"""
generate_params = default_preset()
state.update(generate_params)
return state, *[generate_params[k] for k in presets_params()]
2024-02-06 17:51:34 +01:00
def loader_contains(sampler):
if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]:
return True
else:
return sampler in loaders_samplers[shared.args.loader]
def remove_defaults(state):
defaults = default_preset()
data = {k: state[k] for k in presets_params()}
for k in list(data.keys()):
2024-02-06 17:51:34 +01:00
if data[k] == defaults[k]:
del data[k]
2024-02-06 17:51:34 +01:00
return data
def generate_preset_yaml(state):
data = remove_defaults(state)
return yaml.dump(data, sort_keys=False)