Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-04-20 22:13:43 +00:00)
Add adaptive-p sampler and n-gram speculative decoding support
parent f010aa1612
commit 65de4c30c8
10 changed files with 145 additions and 3 deletions
@@ -235,6 +235,73 @@ class TopNSigmaLogitsWarper(LogitsProcessor):
        return scores


class AdaptivePLogitsWarper(LogitsProcessor):
    '''
    Adaptive-p sampling. A stateful sampler that favors tokens near a target
    probability, using an EMA-based control loop to adapt over time.

    Matches the llama.cpp implementation from PR #17927.
    '''

    DISTRIBUTION_WIDTH = 0.3
    PEAK_LOGIT_VALUE = 5.0
    SHARPNESS = 10.0
    INV_WIDTH = 1.0 / DISTRIBUTION_WIDTH

    def __init__(self, adaptive_target, adaptive_decay, filter_value=-float("Inf"), min_tokens_to_keep=1):
        self.target = adaptive_target
        self.decay = min(adaptive_decay, 0.99)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

        # Initialize EMA at equilibrium (as if target was already achieved)
        if self.decay < 1.0:
            self.weighted_sum = self.target / (1.0 - self.decay)
            self.total_weight = 1.0 / (1.0 - self.decay)
        else:
            self.weighted_sum = 0.0
            self.total_weight = 0.0

    def __call__(self, input_ids, scores):
        logits = scores[0]

        # Compute original probabilities (before transform)
        probs = torch.softmax(logits, dim=-1)

        # Compute adapted target using proportional control on the EMA
        if self.total_weight > 0:
            ema_avg = self.weighted_sum / self.total_weight
        else:
            ema_avg = self.target

        adapted_target = max(0.0, min(1.0, 2.0 * self.target - ema_avg))

        # Adaptive probability transform:
        # quadratic near target for fine differentiation, transitioning
        # to linear decay in the tails for proper suppression after softmax
        dist = torch.abs((probs - adapted_target) * self.INV_WIDTH)
        new_logits = self.PEAK_LOGIT_VALUE - self.SHARPNESS * dist * dist / (1.0 + dist)

        # Preserve already-masked tokens (-inf logits from prior samplers)
        new_logits = torch.where(torch.isfinite(logits), new_logits, logits)

        # Softmax and sample from the transformed distribution
        new_probs = torch.softmax(new_logits, dim=-1)
        selected = torch.multinomial(new_probs, num_samples=1, replacement=True)

        # Update EMA with the original probability of the selected token
        original_prob = probs[selected[0]].item()
        self.weighted_sum = original_prob + self.decay * self.weighted_sum
        self.total_weight = 1.0 + self.decay * self.total_weight

        # Mask all tokens except the selected one
        indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
        indices_to_remove[selected[0]] = False
        indices_to_remove = indices_to_remove.unsqueeze(0)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores


# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsProcessor):
    def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
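To make the control loop above concrete, here is a hedged plain-Python walk-through of the EMA arithmetic (values hypothetical, constants mirroring __init__ and __call__ above):

# Hypothetical walk-through with adaptive_target = 0.3, adaptive_decay = 0.9
target, decay = 0.3, 0.9

# Equilibrium initialization: the EMA average starts exactly at the target.
weighted_sum = target / (1.0 - decay)    # 3.0
total_weight = 1.0 / (1.0 - decay)       # 10.0
ema_avg = weighted_sum / total_weight    # 0.30 -> adapted_target = 2*0.3 - 0.30 = 0.30

# Suppose a high-probability token (original p = 0.9) gets sampled. The EMA
# rises, so the next adapted target drops below the nominal target:
weighted_sum = 0.9 + decay * weighted_sum    # 3.6
total_weight = 1.0 + decay * total_weight    # 10.0
ema_avg = weighted_sum / total_weight        # 0.36
adapted_target = max(0.0, min(1.0, 2.0 * target - ema_avg))    # 0.24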
@@ -575,6 +642,15 @@ def get_logits_processor_patch(self, **kwargs):
            )
        )

    if generation_config.adaptive_target is not None and generation_config.adaptive_target > 0.0:
        warpers_to_add.append(
            AdaptivePLogitsWarper(
                adaptive_target=generation_config.adaptive_target,
                adaptive_decay=generation_config.adaptive_decay,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
        warpers_to_add.append(
            XTCLogitsWarper(
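For context, a minimal standalone sketch of the warper this hunk registers (hypothetical logits; assumes `torch` and the `AdaptivePLogitsWarper` class above are in scope). Because the sampler draws the token itself, exactly one finite logit survives each call:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.0, -float("Inf")]])  # [batch=1, vocab=5]
warper = AdaptivePLogitsWarper(adaptive_target=0.3, adaptive_decay=0.9)

out = warper(input_ids=None, scores=scores)
print(torch.isfinite(out).sum().item())  # 1: every other token is masked to -inf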
@@ -640,6 +716,7 @@ def get_logits_processor_patch(self, **kwargs):
        'TemperatureLogitsWarperCustom': 'temperature',
        'TopALogitsWarper': 'top_a',
        'TopNSigmaLogitsWarper': 'top_n_sigma',
        'AdaptivePLogitsWarper': 'adaptive_p',
        'TopKLogitsWarper': 'top_k',
        'TopPLogitsWarper': 'top_p',
        'TypicalLogitsWarper': 'typical_p',
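The mapping above pairs each warper class with its parameter name so the warper list can be ordered by sampler_priority. A hedged sketch of that reordering (variable names illustrative, not the repo's exact loop; `AdaptivePLogitsWarper` from above):

from transformers import TopKLogitsWarper

# Hypothetical reordering: place each warper according to its parameter's
# position in sampler_priority, via the class -> parameter mapping above.
class_to_param = {'AdaptivePLogitsWarper': 'adaptive_p', 'TopKLogitsWarper': 'top_k'}
sampler_priority = ['adaptive_p', 'top_k']

warpers = [TopKLogitsWarper(top_k=50), AdaptivePLogitsWarper(adaptive_target=0.3, adaptive_decay=0.9)]
warpers.sort(key=lambda w: sampler_priority.index(class_to_param[type(w).__name__]))
print([type(w).__name__ for w in warpers])  # ['AdaptivePLogitsWarper', 'TopKLogitsWarper']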
@@ -688,6 +765,8 @@ def generation_config_init_patch(self, **kwargs):
    self.tfs = kwargs.pop("tfs", 1.0)
    self.top_a = kwargs.pop("top_a", 0.0)
    self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
    self.adaptive_target = kwargs.pop("adaptive_target", 0.0)
    self.adaptive_decay = kwargs.pop("adaptive_decay", 0.9)
    self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
    self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
    self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
@@ -701,7 +780,7 @@ def generation_config_init_patch(self, **kwargs):
    self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    self.xtc_probability = kwargs.pop("xtc_probability", 0)
    self.temperature_last = kwargs.pop("temperature_last", False)
-   self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
+   self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])


def hijack_samplers():
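Putting the pieces together, a hedged end-to-end sketch of enabling the new sampler (the `modules.sampler_hijack` import path is assumed from the patched function names; the webui normally drives this through its own generation code):

from transformers import GenerationConfig
from modules.sampler_hijack import hijack_samplers  # assumed module path

hijack_samplers()  # installs the patched GenerationConfig.__init__ and warper list

# With adaptive_target > 0.0, AdaptivePLogitsWarper is appended, and the
# 'adaptive_p' entry in sampler_priority controls where it runs in the chain.
config = GenerationConfig(do_sample=True, adaptive_target=0.3, adaptive_decay=0.9)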