Mirror of https://github.com/oobabooga/text-generation-webui.git
Add the top N-sigma sampler (#6796)
parent 677d74a6a0
commit 5bcd2d7ad0

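In short: the commit adds a TopNSigmaLogitsWarper that keeps only tokens whose logits satisfy logit >= max_logit - n_sigma * std(logits) and masks the rest to -inf (with a min_tokens_to_keep safety net), and plumbs the new top_n_sigma parameter through the API schema, the loader sampler lists, the default preset and sampler priority, the HF generation path, and the UI.
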
@@ -21,6 +21,7 @@ class GenerationOptions(BaseModel):
     eta_cutoff: float = 0
     tfs: float = 1
     top_a: float = 0
+    top_n_sigma: float = 0
     dry_multiplier: float = 0
     dry_allowed_length: int = 2
     dry_base: float = 1.75

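Since GenerationOptions feeds the request schema of the built-in OpenAI-compatible API, the new field can be set per request. A minimal sketch, assuming the server exposes these fields on its /v1/completions endpoint and was launched with --api on the default port 5000; the prompt, max_tokens and the 1.0 value are illustrative:

    import requests

    # Hypothetical request to a local text-generation-webui instance started with --api.
    payload = {
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "temperature": 1.0,
        "top_n_sigma": 1.0,  # 0 (the default above) leaves the sampler disabled
    }

    response = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
    print(response.json()["choices"][0]["text"])
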
@@ -137,6 +137,7 @@ def transformers_samplers():
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',

@@ -224,6 +225,7 @@ loaders_samplers = {
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',

@@ -288,6 +290,7 @@ loaders_samplers = {
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',

@@ -28,6 +28,7 @@ def default_preset():
         'eta_cutoff': 0,
         'tfs': 1,
         'top_a': 0,
+        'top_n_sigma': 0,
         'dry_multiplier': 0,
         'dry_allowed_length': 2,
         'dry_base': 1.75,

@@ -45,7 +46,7 @@ def default_preset():
         'do_sample': True,
         'dynamic_temperature': False,
         'temperature_last': False,
-        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
+        'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_n_sigma\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
         'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
     }

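With the preset default and the new priority entry in place, enabling the sampler is just a matter of overriding the zero default. A minimal sketch, assuming default_preset() is importable from modules.presets as elsewhere in the code base; the 1.5 value is purely illustrative:

    from modules.presets import default_preset

    params = default_preset()
    params['top_n_sigma'] = 1.5  # illustrative; 0 keeps top-nσ filtering off
    # In the default sampler_priority above, top_n_sigma runs after the
    # temperature stages and before top_k/top_p.
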
@@ -5,7 +5,6 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsProcessor
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,

@@ -193,6 +192,46 @@ class TopALogitsWarper(LogitsProcessor):
         return scores
 
 
+class TopNSigmaLogitsWarper(LogitsProcessor):
+    def __init__(self, n_sigma: float = 2.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        """
+        Initialize Top-nσ Sampling logits warper.
+
+        Args:
+            n_sigma: The threshold multiplier for standard deviation
+            filter_value: Value to assign to filtered logits
+            min_tokens_to_keep: Minimum number of tokens to keep
+        """
+        if n_sigma < 0:
+            raise ValueError(f"`n_sigma` must be a non-negative float, but is {n_sigma}")
+        self.n_sigma = n_sigma
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        # Calculate max of logits
+        max_logit = torch.max(scores, dim=-1, keepdim=True)[0]
+
+        # Calculate standard deviation only on finite values
+        finite_mask = torch.isfinite(scores)
+        finite_scores = scores.masked_fill(~finite_mask, 0.0)
+        std_logit = torch.std(finite_scores, dim=-1, keepdim=True)
+
+        # Create mask where tokens with logits >= max_logit - n_sigma * std_logit are kept
+        threshold = max_logit - self.n_sigma * std_logit
+        indices_to_remove = scores < threshold
+
+        if self.min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep tokens
+            top_k_indices = torch.topk(scores, self.min_tokens_to_keep, dim=-1)[1]
+            indices_to_remove.scatter_(-1, top_k_indices, False)
+
+        # Apply mask by setting filtered tokens to filter_value
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+
+        return scores
+
+
 # Exclude Top Choices (XTC)
 class XTCLogitsWarper(LogitsProcessor):
     def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):

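To make the thresholding concrete, here is a small, self-contained check of the new warper on a made-up five-token vocabulary (a sketch for illustration only; it assumes the class above is in scope):

    import torch

    # Made-up logits for a batch of one and a vocabulary of five tokens.
    logits = torch.tensor([[4.0, 3.5, 1.0, -2.0, -6.0]])

    warper = TopNSigmaLogitsWarper(n_sigma=1.0)
    filtered = warper(input_ids=None, scores=logits)

    # max_logit = 4.0 and torch.std (unbiased) ≈ 4.16, so the threshold is
    # 4.0 - 1.0 * 4.16 ≈ -0.16: the tokens at 4.0, 3.5 and 1.0 survive, while
    # -2.0 and -6.0 are pushed to -inf and get zero probability after softmax.
    print(filtered)  # tensor([[4.0000, 3.5000, 1.0000, -inf, -inf]])
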
@@ -525,6 +564,14 @@ def get_logits_processor_patch(self, **kwargs):
                 )
             )
 
+        if generation_config.top_n_sigma is not None and generation_config.top_n_sigma > 0.0:
+            warpers_to_add.append(
+                TopNSigmaLogitsWarper(
+                    n_sigma=generation_config.top_n_sigma,
+                    min_tokens_to_keep=min_tokens_to_keep
+                )
+            )
+
         if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
             warpers_to_add.append(
                 XTCLogitsWarper(

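Outside of the hijack, the same warper also composes with the stock transformers warpers through a LogitsProcessorList; a minimal standalone sketch (not part of the commit), with an illustrative n_sigma and made-up logits:

    import torch
    from transformers import LogitsProcessorList
    from transformers.generation.logits_process import TemperatureLogitsWarper

    # Temperature first, then top-nσ, mirroring the default sampler_priority
    # order where top_n_sigma comes after the temperature stages.
    warpers = LogitsProcessorList([
        TemperatureLogitsWarper(0.8),
        TopNSigmaLogitsWarper(n_sigma=1.0),
    ])

    input_ids = torch.tensor([[0]])      # dummy context; neither warper reads it
    scores = torch.randn(1, 32000)       # made-up logits over a 32k-token vocabulary
    scores = warpers(input_ids, scores)
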
@@ -589,6 +636,7 @@ def get_logits_processor_patch(self, **kwargs):
         'TailFreeLogitsWarper': 'tfs',
         'TemperatureLogitsWarperCustom': 'temperature',
         'TopALogitsWarper': 'top_a',
+        'TopNSigmaLogitsWarper': 'top_n_sigma',
         'TopKLogitsWarper': 'top_k',
         'TopPLogitsWarper': 'top_p',
         'TypicalLogitsWarper': 'typical_p',

@@ -636,6 +684,7 @@ def generation_config_init_patch(self, **kwargs):
     self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
+    self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
     self.mirostat_tau = kwargs.pop("mirostat_tau", 5)

@@ -649,7 +698,7 @@ def generation_config_init_patch(self, **kwargs):
     self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
     self.xtc_probability = kwargs.pop("xtc_probability", 0)
     self.temperature_last = kwargs.pop("temperature_last", False)
-    self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
+    self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
 
 
 def hijack_samplers():

@@ -302,6 +302,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         'xtc_probability',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',

@@ -183,6 +183,7 @@ def list_interface_input_elements():
         'eta_cutoff',
         'tfs',
         'top_a',
+        'top_n_sigma',
         'dry_multiplier',
         'dry_allowed_length',
         'dry_base',

@@ -37,6 +37,7 @@ def create_ui(default_preset):
 
                 gr.Markdown('## Curve cutoff')
                 shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=generate_params['min_p'], step=0.01, label='min_p')
+                shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=generate_params['top_n_sigma'], step=0.01, label='top_n_sigma')
                 shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
                 shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
                 shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')