Mirror of https://github.com/oobabooga/text-generation-webui.git
Quadratic sampling (#5403)

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>

parent e98d1086f5
commit b6077b02e4

8 changed files with 45 additions and 15 deletions
@@ -15,8 +15,12 @@ from modules import shared
 global_scores = None
 
 
-class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
-    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
+class ModifiedTemperatureLogitsWarper(LogitsWarper):
+    '''
+    Based on the original Transformers temperature logits warper, this
+    adds support for dynamic temperature and quadratic sampling.
+    '''
+    def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
         if not isinstance(temperature, float) or not (temperature > 0):
             except_msg = (
                 f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
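For context, LogitsWarper is the small transformers base class being subclassed here: a callable that receives the running input_ids and the current logits and returns adjusted logits. A minimal sketch of that interface as this file uses it (a toy do-nothing warper, not part of the commit):

import torch
from transformers import LogitsWarper

class IdentityWarper(LogitsWarper):
    # transformers calls each warper in sequence with (input_ids, scores)
    # and feeds the returned scores to the next warper in the list.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores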
@@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
         self.dynatemp_low = dynatemp_low
         self.dynatemp_high = dynatemp_high
         self.dynatemp_exponent = dynatemp_exponent
+        self.smoothing_factor = smoothing_factor
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
-        # Regular temperature
-        if not self.dynamic_temperature:
-            scores = scores / self.temperature
-            return scores
+        # Quadratic sampling
+        if self.smoothing_factor > 0:
+
+            # Compute the maximum logit value
+            max_logit = scores.max()
+
+            # Apply the quadratic transformation
+            transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
+
+            # No need to print the top 5 logits since this is not required
+            # print("Original top 5 logits: ", torch.topk(scores, 5))
+            # print("New top 5 logits: ", torch.topk(transformed_logits, 5))
+
+            return transformed_logits
 
         # Dynamic temperature
-        else:
+        elif self.dynamic_temperature:
             min_temp = self.dynatemp_low
             max_temp = self.dynatemp_high
             exponent_val = self.dynatemp_exponent
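In words, the quadratic branch maps each logit x to max_logit - smoothing_factor * (x - max_logit)^2. The top logit is a fixed point (x = max_logit maps to itself), and every other logit is pulled down by smoothing_factor times the square of its distance from the top, so a gap g becomes smoothing_factor * g^2. Worked numbers for a top logit of 5 and a runner-up of 3 (gap 2): smoothing_factor = 1.0 gives 5 - 1.0 * 2^2 = 1 (gap widens to 4, sharper distribution); smoothing_factor = 0.5 gives 5 - 0.5 * 4 = 3 (gap unchanged); smoothing_factor = 0.25 gives 5 - 0.25 * 4 = 4 (gap shrinks to 1, flatter among the top candidates). Gaps smaller than 1 / smoothing_factor shrink while larger ones grow, which is the sense in which this smooths near-ties at the top while suppressing the tail; a runnable sketch follows the next hunk.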
@@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
 
             return scores
 
+        # Regular temperature
+        else:
+            scores = scores / self.temperature
+            return scores
+
 
 class MinPLogitsWarper(LogitsWarper):
     def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
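A minimal standalone sketch of the same transform, to see its effect on a toy distribution (tensor values are illustrative, not from the commit):

import torch

def quadratic_transform(scores: torch.Tensor, smoothing_factor: float) -> torch.Tensor:
    # Same expression as in __call__ above: the max logit is a fixed point and
    # every other logit drops by smoothing_factor * (distance from max)^2.
    max_logit = scores.max()
    return -(smoothing_factor * (scores - max_logit) ** 2) + max_logit

logits = torch.tensor([5.0, 3.0, 1.0, -2.0])
for s in (0.25, 0.5, 1.0):
    print(s, torch.softmax(quadratic_transform(logits, s), dim=-1))

The ranking of tokens never changes, since the map is increasing for logits at or below the max; only the gaps between them are rescaled.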
@@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
     generation_config.temperature = float(generation_config.temperature)
 
     temperature = generation_config.temperature
-    if generation_config.dynamic_temperature:
+    if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
         # Make sure TemperatureLogitsWarper will be created by temporarily
         # setting temperature to a value != 1.
         generation_config.temperature = 1.1
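The 1.1 is a throwaway value: transformers only builds a TemperatureLogitsWarper when the configured temperature differs from 1.0, so the patch forces one into the warper list and then swaps it out below, having already saved the real value in the temperature local. Roughly the upstream guard being worked around (a paraphrased sketch, not the exact transformers source):

if generation_config.temperature is not None and generation_config.temperature != 1.0:
    warpers.append(TemperatureLogitsWarper(generation_config.temperature))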
@@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
     warpers = self._get_logits_warper_old(generation_config)
     for i in range(len(warpers)):
         if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
-            warpers[i] = TemperatureLogitsWarperWithDynatemp(
+            warpers[i] = ModifiedTemperatureLogitsWarper(
                 temperature,
                 generation_config.dynamic_temperature,
                 generation_config.dynatemp_low,
                 generation_config.dynatemp_high,
-                generation_config.dynatemp_exponent
+                generation_config.dynatemp_exponent,
+                generation_config.smoothing_factor
             )
 
     warpers_to_add = LogitsProcessorList()
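For reference, a usage sketch of the swapped-in warper with the new argument order (values are illustrative; assumes the class is importable from this module):

import torch

warper = ModifiedTemperatureLogitsWarper(
    1.0,    # temperature: the real value saved before the 1.1 placeholder
    False,  # dynamic_temperature
    1,      # dynatemp_low
    1,      # dynatemp_high
    1,      # dynatemp_exponent
    0.3,    # smoothing_factor: > 0 selects the quadratic branch in __call__
)
scores = warper(torch.tensor([[0]]), torch.tensor([[5.0, 3.0, 1.0]]))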
@@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
     if generation_config.temperature_last:
         temperature_idx = None
         for i in range(len(warpers)):
-            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
+            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
                 temperature_idx = i
                 break
 
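The index recorded here feeds the temperature_last option, which moves the temperature warper behind the truncation samplers so that plain temperature, dynamic temperature, or the quadratic transform applies after filters like min_p rather than before. The actual move happens outside this hunk; a plausible sketch of it would be:

if temperature_idx is not None:
    warpers.append(warpers.pop(temperature_idx))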
@@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
     repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
 
     do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
     if do_rep_pen_hijack:
-        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
-        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+        kwargs['generation_config'].repetition_penalty = 1.1  # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
 
     result = self._get_logits_processor_old(**kwargs)
@@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
         self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
         self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
         self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
+        self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
         self.tfs = kwargs.pop("tfs", 1.0)
         self.top_a = kwargs.pop("top_a", 0.0)
         self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
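The 0.0 default is what keeps old clients working: kwargs.pop returns the default when a request never mentions smoothing_factor, and a value of 0.0 skips the quadratic branch in __call__, so existing API payloads sample exactly as before. The pattern in isolation (hypothetical payload for illustration):

payload = {"temperature": 0.7}  # an old-style request, no smoothing_factor key
smoothing_factor = payload.pop("smoothing_factor", 0.0)
assert smoothing_factor == 0.0  # quadratic sampling stays disabled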