diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 5f0e0128..ea688897 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -21,6 +21,7 @@ class GenerationOptions(BaseModel):
     eta_cutoff: float = 0
     tfs: float = 1
     top_a: float = 0
+    top_n_sigma: float = 0
     dry_multiplier: float = 0
     dry_allowed_length: int = 2
     dry_base: float = 1.75
diff --git a/modules/loaders.py b/modules/loaders.py
index cd864e40..88ded1d1 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -137,6 +137,7 @@ def transformers_samplers():
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
@@ -224,6 +225,7 @@
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
@@ -288,6 +290,7 @@
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
diff --git a/modules/presets.py b/modules/presets.py
index b841af53..7cab2af0 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -28,6 +28,7 @@ def default_preset():
         'eta_cutoff': 0,
         'tfs': 1,
         'top_a': 0,
+        'top_n_sigma': 0,
         'dry_multiplier': 0,
         'dry_allowed_length': 2,
         'dry_base': 1.75,
@@ -45,7 +46,7 @@ def default_preset():
         'do_sample': True,
         'dynamic_temperature': False,
         'temperature_last': False,
-        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
+        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_n_sigma\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
         'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
     }
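What the new sampler does, as implemented by the `TopNSigmaLogitsWarper` added to `modules/sampler_hijack.py` below: it keeps every token whose logit is within `n_sigma` standard deviations of the maximum logit (`logit >= max_logit - n_sigma * std_logit`) and masks the rest to `-inf`, so the cutoff adapts to how spread out the logits are rather than using a fixed count or probability mass. For example, for logits `[4.0, 3.5, 1.0, -2.0]` the max is 4.0 and the sample standard deviation is exactly 2.75, so `n_sigma = 1` gives a threshold of 1.25 and only the first two tokens survive. A runnable version of this arithmetic follows the `sampler_hijack.py` diff.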
diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index e0df49c3..e6883289 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -5,7 +5,6 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -193,6 +192,50 @@
         return scores
 
 
+class TopNSigmaLogitsWarper(LogitsProcessor):
+    def __init__(self, n_sigma: float = 2.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        """
+        Initialize the top-nσ sampling logits warper.
+
+        Args:
+            n_sigma: The threshold multiplier for the standard deviation
+            filter_value: Value to assign to filtered logits
+            min_tokens_to_keep: Minimum number of tokens to keep
+        """
+        if n_sigma < 0:
+            raise ValueError(f"`n_sigma` must be a non-negative float, but is {n_sigma}")
+        self.n_sigma = n_sigma
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Calculate the max of the logits
+        max_logit = torch.max(scores, dim=-1, keepdim=True)[0]
+
+        # Calculate the standard deviation over finite logits only, so tokens
+        # already filtered to -inf do not distort the estimate
+        finite_mask = torch.isfinite(scores)
+        finite_count = finite_mask.sum(dim=-1, keepdim=True).clamp(min=2)
+        masked_scores = scores.masked_fill(~finite_mask, 0.0)
+        mean_logit = masked_scores.sum(dim=-1, keepdim=True) / finite_count
+        squared_dev = ((masked_scores - mean_logit).masked_fill(~finite_mask, 0.0)) ** 2
+        std_logit = (squared_dev.sum(dim=-1, keepdim=True) / (finite_count - 1)).sqrt()
+
+        # Keep tokens with logits >= max_logit - n_sigma * std_logit
+        threshold = max_logit - self.n_sigma * std_logit
+        indices_to_remove = scores < threshold
+
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep tokens
+            top_k_indices = torch.topk(scores, self.min_tokens_to_keep, dim=-1)[1]
+            indices_to_remove.scatter_(-1, top_k_indices, False)
+
+        # Apply the mask by setting filtered tokens to filter_value
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
+        return scores
+
+
 # Exclude Top Choices (XTC)
 class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
@@ -525,6 +568,14 @@ def get_logits_processor_patch(self, **kwargs):
             )
         )
 
+    if generation_config.top_n_sigma is not None and generation_config.top_n_sigma > 0.0:
+        warpers_to_add.append(
+            TopNSigmaLogitsWarper(
+                n_sigma=generation_config.top_n_sigma,
+                min_tokens_to_keep=min_tokens_to_keep
+            )
+        )
+
     if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
         warpers_to_add.append(
             XTCLogitsWarper(
@@ -589,6 +640,7 @@ def get_logits_processor_patch(self, **kwargs):
         'TailFreeLogitsWarper': 'tfs',
         'TemperatureLogitsWarperCustom': 'temperature',
         'TopALogitsWarper': 'top_a',
+        'TopNSigmaLogitsWarper': 'top_n_sigma',
         'TopKLogitsWarper': 'top_k',
         'TopPLogitsWarper': 'top_p',
         'TypicalLogitsWarper': 'typical_p',
@@ -636,6 +688,7 @@ def generation_config_init_patch(self, **kwargs):
     self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
+    self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
     self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
@@ -649,7 +702,7 @@
     self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
     self.xtc_probability = kwargs.pop("xtc_probability", 0)
     self.temperature_last = kwargs.pop("temperature_last", False)
-    self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
+    self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
 
 
 def hijack_samplers():
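To sanity-check the warper outside the webui, here is a minimal sketch that assumes only PyTorch and copies the threshold rule from `TopNSigmaLogitsWarper` above; the toy logits are invented for illustration, with `-inf` standing in for a token some earlier sampler already banned:

```python
import torch

# Toy batch of logits; -inf mimics a token already filtered by another sampler.
scores = torch.tensor([[4.0, 3.5, 1.0, -2.0, float("-inf")]])
n_sigma = 1.0

# Same rule as TopNSigmaLogitsWarper: keep logits >= max - n_sigma * std,
# with the std computed over finite values only.
finite_mask = torch.isfinite(scores)
finite_count = finite_mask.sum(dim=-1, keepdim=True).clamp(min=2)
masked_scores = scores.masked_fill(~finite_mask, 0.0)
mean_logit = masked_scores.sum(dim=-1, keepdim=True) / finite_count
squared_dev = ((masked_scores - mean_logit).masked_fill(~finite_mask, 0.0)) ** 2
std_logit = (squared_dev.sum(dim=-1, keepdim=True) / (finite_count - 1)).sqrt()

threshold = scores.max(dim=-1, keepdim=True).values - n_sigma * std_logit
print(threshold)           # tensor([[1.2500]]) = 4.0 - 1.0 * 2.75
print(scores >= threshold) # tensor([[ True,  True, False, False, False]])
```

Only the two logits near the peak survive; with flatter logits the std grows and the threshold drops, so more tokens pass.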
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 152b2b8d..eff6495e 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -302,6 +302,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         'xtc_probability',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
diff --git a/modules/ui.py b/modules/ui.py
index b776e19c..adbb67b0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -183,6 +183,7 @@ def list_interface_input_elements():
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 265840ed..846fcfe7 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -37,6 +37,7 @@ def create_ui(default_preset):
                 gr.Markdown('## Curve cutoff')
                 shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=generate_params['min_p'], step=0.01, label='min_p')
+                shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=generate_params['top_n_sigma'], step=0.01, label='top_n_sigma')
                 shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
                 shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
                 shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
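Because `top_n_sigma` is now a field on `GenerationOptions` in `extensions/openai/typing.py`, the parameter can also be set through the OpenAI-compatible API. A minimal sketch, assuming the API extension is running at its default local address and port (adjust to however the server was launched):

```python
import requests

# Assumes the webui was started with the openai extension enabled;
# the host and port here are the usual defaults, not guaranteed.
resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "top_n_sigma": 1.0,  # the new parameter; the default of 0 disables it
    },
)
print(resp.json()["choices"][0]["text"])
```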