2023-06-14 01:34:35 +02:00
|
|
|
import functools
|
2024-02-06 17:51:34 +01:00
|
|
|
import pprint
|
2023-06-14 01:34:35 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
2023-11-18 22:31:41 +01:00
|
|
|
from modules import shared
|
|
|
|
|
from modules.loaders import loaders_samplers
|
2024-01-09 03:28:35 +01:00
|
|
|
from modules.logging_colors import logger
|
2023-11-18 22:31:41 +01:00
|
|
|
|
2023-06-14 01:34:35 +02:00
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
def default_preset():
    """Return a freshly built dict of default ("neutral") generation parameters.

    A new dict is created on every call, so callers may mutate the result.
    The key order here is the order reported by ``presets_params()`` and
    therefore the order of the values returned by the ``*_for_ui`` helpers.

    When ``shared.args.portable`` is truthy, the newline-separated
    ``sampler_priority`` string is reduced to a fixed subset of sampler
    names, keeping their original relative order.
    """
    params = {
        'temperature': 1,
        'dynatemp_low': 1,
        'dynatemp_high': 1,
        'dynatemp_exponent': 1,
        'smoothing_factor': 0,
        'smoothing_curve': 1,
        'min_p': 0,
        'top_p': 1,
        'top_k': 0,
        'typical_p': 1,
        'xtc_threshold': 0.1,
        'xtc_probability': 0,
        'epsilon_cutoff': 0,
        'eta_cutoff': 0,
        'tfs': 1,
        'top_a': 0,
        'top_n_sigma': 0,
        'dry_multiplier': 0,
        'dry_allowed_length': 2,
        'dry_base': 1.75,
        'repetition_penalty': 1,
        'frequency_penalty': 0,
        'presence_penalty': 0,
        'encoder_repetition_penalty': 1,
        'no_repeat_ngram_size': 0,
        'repetition_penalty_range': 1024,
        'penalty_alpha': 0,
        'guidance_scale': 1,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'do_sample': True,
        'dynamic_temperature': False,
        'temperature_last': False,
        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
        'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
    }

    if not shared.args.portable:
        return params

    # Portable mode: keep only these samplers, preserving the priority order above.
    portable_samplers = {"dry", "top_k", "top_p", "top_n_sigma", "min_p", "temperature", "xtc", "typical_p", "repetition_penalty"}
    kept = [name for name in params['sampler_priority'].split('\n') if name in portable_samplers]
    params['sampler_priority'] = '\n'.join(kept)
    return params
|
|
|
|
|
|
2023-08-01 04:13:29 +02:00
|
|
|
|
2023-08-06 22:22:48 +02:00
|
|
|
def presets_params():
    """Return the list of preset parameter names, in ``default_preset()`` key order."""
    return list(default_preset().keys())
|
|
|
|
|
|
|
|
|
|
|
2024-03-28 20:45:03 +01:00
|
|
|
def load_preset(name, verbose=False):
    """Build generation parameters from the defaults overlaid with a preset file.

    Args:
        name: Preset name. The file read is ``user_data/presets/{name}.yaml``
            (relative to the current working directory). ``None``, ``'None'``
            and ``''`` mean "no preset" and yield the defaults unchanged.
        verbose: When true, log the preset name and pretty-print the
            parameters whose values differ from the defaults.

    Returns:
        A new dict holding every key of ``default_preset()``; values found in
        the preset file override the defaults, and keys present only in the
        file are added as-is. If the file is missing, empty, or does not
        contain a mapping, the defaults are returned (an error is logged for
        the missing / non-mapping cases).
    """
    generate_params = default_preset()

    if name not in ['None', None, '']:
        # NOTE(review): `name` is interpolated into the path unchecked; if it can
        # come from an external API, a value containing "../" could point outside
        # user_data/presets — confirm callers only pass known preset names.
        path = Path(f'user_data/presets/{name}.yaml')
        if path.exists():
            # Read as UTF-8 explicitly rather than relying on the locale encoding.
            with open(path, 'r', encoding='utf-8') as infile:
                preset = yaml.safe_load(infile)

            # An empty YAML file loads as None: treat it as "no overrides"
            # instead of crashing on iteration.
            if preset is None:
                preset = {}

            if isinstance(preset, dict):
                generate_params.update(preset)
            else:
                logger.error(f"The preset \"{name}\" at \"{path}\" is not a mapping of parameter names to values. Using the default parameters.")
        else:
            logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.")

    if verbose:
        logger.info(f"\"{name}\" preset:")
        pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(generate_params))

    return generate_params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.cache
def _load_preset_cached(name):
    """Cache ``load_preset(name)`` per name for the lifetime of the process."""
    return load_preset(name)


def load_preset_memoized(name):
    """Return the parameters of preset *name*, loading the file at most once per name.

    Results are cached with ``functools.cache``, so later edits to a preset
    file are not seen until ``load_preset_memoized.cache_clear()`` is called.

    A shallow copy of the cached dict is returned so that a caller mutating
    the result cannot corrupt what later callers receive. The values are the
    scalars and strings loaded from YAML / ``default_preset()``, so a shallow
    copy is presumed sufficient — TODO confirm presets never hold nested
    containers.
    """
    return dict(_load_preset_cached(name))


# Keep the functools.cache management API reachable on the public name.
load_preset_memoized.cache_clear = _load_preset_cached.cache_clear
load_preset_memoized.cache_info = _load_preset_cached.cache_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_preset_for_ui(name, state):
    """Load preset *name* into *state* and return the values for the UI.

    *state* is updated in place with the loaded parameters (logged verbosely).

    Returns:
        A tuple ``(state, v1, v2, ...)`` with one value per name in
        ``presets_params()``, in that order.
    """
    params = load_preset(name, verbose=True)
    state.update(params)
    ui_values = tuple(params[key] for key in presets_params())
    return (state,) + ui_values
|
2023-06-14 01:34:35 +02:00
|
|
|
|
|
|
|
|
|
2025-06-08 06:58:02 +02:00
|
|
|
def reset_preset_for_ui(name, state):
    """Reset current preset to its saved values from file.

    Re-reading the preset file and pushing its values into *state* is exactly
    what ``load_preset_for_ui`` does, so delegate to it instead of keeping a
    duplicate copy of that logic.

    Returns:
        A tuple ``(state, v1, v2, ...)`` with one value per name in
        ``presets_params()``, in that order.
    """
    return load_preset_for_ui(name, state)
|
2023-11-18 22:31:41 +01:00
|
|
|
|
|
|
|
|
|
2025-06-08 06:58:02 +02:00
|
|
|
def neutralize_samplers_for_ui(state):
    """Set all samplers to their default/neutral values.

    *state* is updated in place with ``default_preset()``.

    Returns:
        A tuple ``(state, v1, v2, ...)`` with one default value per name in
        ``presets_params()``, in that order.
    """
    neutral = default_preset()
    state.update(neutral)
    return (state, *(neutral[key] for key in presets_params()))
|
|
|
|
|
|
|
|
|
|
|
2024-02-06 17:51:34 +01:00
|
|
|
def loader_contains(sampler):
    """Return whether the active loader (``shared.args.loader``) lists *sampler*.

    ``'dynamic_temperature'`` is also treated as supported whenever the
    loader lists ``'dynatemp_low'``.
    """
    supported = loaders_samplers[shared.args.loader]
    if sampler == 'dynamic_temperature':
        return 'dynatemp_low' in supported or sampler in supported

    return sampler in supported
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_defaults(state):
    """Return the preset parameters in *state* whose values differ from the defaults.

    Only the keys of ``default_preset()`` are considered; any other keys in
    *state* are ignored. The result keeps ``default_preset()`` key order.

    Raises:
        KeyError: if *state* (a dict) lacks one of the preset parameter keys.
    """
    # One default_preset() call serves as both the key list and the reference
    # values (presets_params() would just rebuild the same dict again).
    defaults = default_preset()
    return {key: state[key] for key, default in defaults.items() if state[key] != default}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_preset_yaml(state):
    """Serialize the non-default preset parameters of *state* as a YAML string.

    Keys are written in ``remove_defaults`` order (``sort_keys=False``)
    rather than alphabetically.
    """
    return yaml.dump(remove_defaults(state), sort_keys=False)
|