Quadratic sampling (#5403)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Author: kalomaze
Date: 2024-02-03 21:20:02 -06:00 (committed by GitHub)
Parent: e98d1086f5
Commit: b6077b02e4
8 changed files with 45 additions and 15 deletions


@@ -15,8 +15,12 @@ from modules import shared

 global_scores = None

-class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
-    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
+class ModifiedTemperatureLogitsWarper(LogitsWarper):
+    '''
+    Based on the original Transformers temperature logits warper, this
+    adds support for dynamic temperature and quadratic sampling.
+    '''
+    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
         if not isinstance(temperature, float) or not (temperature > 0):
             except_msg = (
                 f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
         self.dynatemp_low = dynatemp_low
         self.dynatemp_high = dynatemp_high
         self.dynatemp_exponent = dynatemp_exponent
+        self.smoothing_factor = smoothing_factor

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

-        # Regular temperature
-        if not self.dynamic_temperature:
-            scores = scores / self.temperature
-            return scores
+        # Quadratic sampling
+        if self.smoothing_factor > 0:
+
+            # Compute the maximum logit value
+            max_logit = scores.max()
+
+            # Apply the quadratic transformation
+            transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
+
+            # No need to print the top 5 logits since this is not required
+            # print("Original top 5 logits: ", torch.topk(scores, 5))
+            # print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+
+            return transformed_logits

         # Dynamic temperature
-        else:
+        elif self.dynamic_temperature:
             min_temp = self.dynatemp_low
             max_temp = self.dynatemp_high
             exponent_val = self.dynatemp_exponent
@@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
             return scores

+        # Regular temperature
+        else:
+            scores = scores / self.temperature
+            return scores
+

 class MinPLogitsWarper(LogitsWarper):
     def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
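
Editor's note: the transformation above can be read as logits' = max_logit - smoothing_factor * (logits - max_logit)^2. The top logit is a fixed point, and every other logit is pushed down by the squared distance from the top, so larger smoothing_factor values punish distant tokens more heavily. A minimal standalone sketch of the same math (the name quadratic_transform is illustrative, not from this commit):

    import torch

    def quadratic_transform(scores: torch.Tensor, smoothing_factor: float) -> torch.Tensor:
        # Same transformation as the warper above: the max logit stays fixed,
        # and lower logits drop quadratically with their distance from it.
        max_logit = scores.max()
        return -(smoothing_factor * (scores - max_logit) ** 2) + max_logit

    logits = torch.tensor([4.0, 3.0, 2.0, 0.0])
    for k in (0.2, 1.0, 5.0):
        print(k, torch.softmax(quadratic_transform(logits, k), dim=-1).tolist())

    # This also shows why the warper guards on smoothing_factor > 0: with
    # k = 0 the transform maps every logit to max_logit, i.e. a uniform
    # distribution, rather than leaving the scores unchanged.
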
@@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
     generation_config.temperature = float(generation_config.temperature)
     temperature = generation_config.temperature
-    if generation_config.dynamic_temperature:
+    if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
         # Make sure TemperatureLogitsWarper will be created by temporarily
         # setting temperature to a value != 1.
         generation_config.temperature = 1.1
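
Editor's note: the 1.1 placeholder works because Transformers (in the versions this patch targets) only creates a TemperatureLogitsWarper when the configured temperature differs from 1.0, roughly as in this paraphrase (not the library's verbatim code):

    # Paraphrase of the guard inside transformers' _get_logits_warper:
    # without the temporary temperature = 1.1, a request that sets only
    # smoothing_factor would produce no TemperatureLogitsWarper for the
    # loop below to swap out.
    if generation_config.temperature is not None and generation_config.temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(generation_config.temperature))
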
@@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
     warpers = self._get_logits_warper_old(generation_config)
     for i in range(len(warpers)):
         if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
-            warpers[i] = TemperatureLogitsWarperWithDynatemp(
+            warpers[i] = ModifiedTemperatureLogitsWarper(
                 temperature,
                 generation_config.dynamic_temperature,
                 generation_config.dynatemp_low,
                 generation_config.dynatemp_high,
-                generation_config.dynatemp_exponent
+                generation_config.dynatemp_exponent,
+                generation_config.smoothing_factor
             )

     warpers_to_add = LogitsProcessorList()
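
Editor's note: a hypothetical smoke test for the swapped-in warper; the import path is an assumption, since the excerpt does not show which file is being changed:

    import torch
    from modules.sampler_hijack import ModifiedTemperatureLogitsWarper  # assumed path

    warper = ModifiedTemperatureLogitsWarper(
        temperature=1.0,            # ignored while smoothing_factor > 0
        dynamic_temperature=False,
        dynatemp_low=1.0,
        dynatemp_high=1.0,
        dynatemp_exponent=1.0,
        smoothing_factor=0.25,
    )
    scores = torch.tensor([[6.0, 4.0, 1.0]])
    input_ids = torch.tensor([[0]])     # unused by this warper
    print(warper(input_ids, scores))    # tensor([[ 6.0000,  5.0000, -0.2500]])
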
@@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
     if generation_config.temperature_last:
         temperature_idx = None
         for i in range(len(warpers)):
-            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
+            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
                 temperature_idx = i
                 break
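
Editor's note: the reordering that temperature_last triggers lies outside the changed lines, but it presumably amounts to moving the located warper to the end of the list so temperature (or the quadratic transform) is applied after all other warpers, e.g.:

    # Sketch of the temperature_last reordering (the actual lines are not
    # shown in this hunk's context):
    if temperature_idx is not None:
        warpers.append(warpers.pop(temperature_idx))
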
@@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
     repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
     if do_rep_pen_hijack:
-        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
-        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+        kwargs['generation_config'].repetition_penalty = 1.1  # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created

     result = self._get_logits_processor_old(**kwargs)
@@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
     self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
     self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
     self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
+    self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
     self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
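
Editor's note: with the generation_config_init_patch applied, the new parameter travels through GenerationConfig exactly like the existing dynatemp options, and the 0.0 default keeps quadratic sampling disabled. A hypothetical usage sketch:

    from transformers import GenerationConfig

    # Only meaningful once the webui's generation_config_init_patch is
    # active; stock transformers does not define smoothing_factor.
    config = GenerationConfig(do_sample=True, smoothing_factor=0.25)
    print(config.smoothing_factor)  # 0.25
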