Add the top N-sigma sampler (#6796)

This commit is contained in:
oobabooga 2025-03-14 16:45:11 -03:00 committed by GitHub
parent 677d74a6a0
commit 5bcd2d7ad0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 60 additions and 3 deletions

View file

@ -5,7 +5,6 @@ import random
import torch
import transformers
from transformers import LogitsProcessor
from transformers.generation.logits_process import (
LogitNormalization,
LogitsProcessor,
@ -193,6 +192,46 @@ class TopALogitsWarper(LogitsProcessor):
return scores
class TopNSigmaLogitsWarper(LogitsProcessor):
def __init__(self, n_sigma: float = 2.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
"""
Initialize Top-nσ Sampling logits warper.
Args:
n_sigma: The threshold multiplier for standard deviation
filter_value: Value to assign to filtered logits
min_tokens_to_keep: Minimum number of tokens to keep
"""
if n_sigma < 0:
raise ValueError(f"`n_sigma` must be a non-negative float, but is {n_sigma}")
self.n_sigma = n_sigma
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate max of logits
max_logit = torch.max(scores, dim=-1, keepdim=True)[0]
# Calculate standard deviation only on finite values
finite_mask = torch.isfinite(scores)
finite_scores = scores.masked_fill(~finite_mask, 0.0)
std_logit = torch.std(finite_scores, dim=-1, keepdim=True)
# Create mask where tokens with logits >= max_logit - n_sigma * std_logit are kept
threshold = max_logit - self.n_sigma * std_logit
indices_to_remove = scores < threshold
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep tokens
top_k_indices = torch.topk(scores, self.min_tokens_to_keep, dim=-1)[1]
indices_to_remove.scatter_(-1, top_k_indices, False)
# Apply mask by setting filtered tokens to filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsProcessor):
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
@ -525,6 +564,14 @@ def get_logits_processor_patch(self, **kwargs):
)
)
if generation_config.top_n_sigma is not None and generation_config.top_n_sigma > 0.0:
warpers_to_add.append(
TopNSigmaLogitsWarper(
n_sigma=generation_config.top_n_sigma,
min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
warpers_to_add.append(
XTCLogitsWarper(
@ -589,6 +636,7 @@ def get_logits_processor_patch(self, **kwargs):
'TailFreeLogitsWarper': 'tfs',
'TemperatureLogitsWarperCustom': 'temperature',
'TopALogitsWarper': 'top_a',
'TopNSigmaLogitsWarper': 'top_n_sigma',
'TopKLogitsWarper': 'top_k',
'TopPLogitsWarper': 'top_p',
'TypicalLogitsWarper': 'typical_p',
@ -636,6 +684,7 @@ def generation_config_init_patch(self, **kwargs):
self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
@ -649,7 +698,7 @@ def generation_config_init_patch(self, **kwargs):
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
self.xtc_probability = kwargs.pop("xtc_probability", 0)
self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
def hijack_samplers():