Quadratic sampling (#5403)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
kalomaze 2024-02-03 21:20:02 -06:00 committed by GitHub
parent e98d1086f5
commit b6077b02e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 45 additions and 15 deletions

View file

@ -159,6 +159,7 @@ def transformers_samplers():
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',
@ -233,6 +234,7 @@ loaders_samplers = {
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',
@ -289,6 +291,7 @@ loaders_samplers = {
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',

View file

@ -17,6 +17,7 @@ def default_preset():
'dynatemp_low': 1,
'dynatemp_high': 1,
'dynatemp_exponent': 1,
'smoothing_factor': 0,
'top_p': 1,
'min_p': 0,
'top_k': 0,

View file

@ -15,8 +15,12 @@ from modules import shared
global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
class ModifiedTemperatureLogitsWarper(LogitsWarper):
'''
Based on the original Transformers temperature logits warper, this
adds support for dynamic temperature and quadratic sampling.
'''
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
self.dynatemp_low = dynatemp_low
self.dynatemp_high = dynatemp_high
self.dynatemp_exponent = dynatemp_exponent
self.smoothing_factor = smoothing_factor
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Regular temperature
if not self.dynamic_temperature:
scores = scores / self.temperature
return scores
# Quadratic sampling
if self.smoothing_factor > 0:
# Compute the maximum logit value
max_logit = scores.max()
# Apply the quadratic transformation
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
# No need to print the top 5 logits since this is not required
# print("Original top 5 logits: ", torch.topk(scores, 5))
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
return transformed_logits
# Dynamic temperature
else:
elif self.dynamic_temperature:
min_temp = self.dynatemp_low
max_temp = self.dynatemp_high
exponent_val = self.dynatemp_exponent
@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
return scores
# Regular temperature
else:
scores = scores / self.temperature
return scores
class MinPLogitsWarper(LogitsWarper):
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
generation_config.temperature = float(generation_config.temperature)
temperature = generation_config.temperature
if generation_config.dynamic_temperature:
if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1
@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(
warpers[i] = ModifiedTemperatureLogitsWarper(
temperature,
generation_config.dynamic_temperature,
generation_config.dynatemp_low,
generation_config.dynatemp_high,
generation_config.dynatemp_exponent
generation_config.dynatemp_exponent,
generation_config.smoothing_factor
)
warpers_to_add = LogitsProcessorList()
@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
temperature_idx = i
break
@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
if do_rep_pen_hijack:
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
result = self._get_logits_processor_old(**kwargs)
@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)

View file

@ -285,8 +285,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
if k in state:
generate_params[k] = state[k]
if state['negative_prompt'] != '':
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])

View file

@ -120,6 +120,7 @@ def list_interface_input_elements():
'dynatemp_low',
'dynatemp_high',
'dynatemp_exponent',
'smoothing_factor',
'top_p',
'min_p',
'top_k',

View file

@ -49,6 +49,7 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])