mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-04 14:17:28 +00:00
Quadratic sampling (#5403)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
e98d1086f5
commit
b6077b02e4
8 changed files with 45 additions and 15 deletions
|
|
@ -159,6 +159,7 @@ def transformers_samplers():
|
|||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
|
@ -233,6 +234,7 @@ loaders_samplers = {
|
|||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
|
@ -289,6 +291,7 @@ loaders_samplers = {
|
|||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ def default_preset():
|
|||
'dynatemp_low': 1,
|
||||
'dynatemp_high': 1,
|
||||
'dynatemp_exponent': 1,
|
||||
'smoothing_factor': 0,
|
||||
'top_p': 1,
|
||||
'min_p': 0,
|
||||
'top_k': 0,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,12 @@ from modules import shared
|
|||
global_scores = None
|
||||
|
||||
|
||||
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
||||
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
|
||||
class ModifiedTemperatureLogitsWarper(LogitsWarper):
|
||||
'''
|
||||
Based on the original Transformers temperature logits warper, this
|
||||
adds support for dynamic temperature and quadratic sampling.
|
||||
'''
|
||||
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
|
|
@ -32,16 +36,27 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
|||
self.dynatemp_low = dynatemp_low
|
||||
self.dynatemp_high = dynatemp_high
|
||||
self.dynatemp_exponent = dynatemp_exponent
|
||||
self.smoothing_factor = smoothing_factor
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
# Regular temperature
|
||||
if not self.dynamic_temperature:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
# Quadratic sampling
|
||||
if self.smoothing_factor > 0:
|
||||
|
||||
# Compute the maximum logit value
|
||||
max_logit = scores.max()
|
||||
|
||||
# Apply the quadratic transformation
|
||||
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
|
||||
|
||||
# No need to print the top 5 logits since this is not required
|
||||
# print("Original top 5 logits: ", torch.topk(scores, 5))
|
||||
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
|
||||
|
||||
return transformed_logits
|
||||
|
||||
# Dynamic temperature
|
||||
else:
|
||||
elif self.dynamic_temperature:
|
||||
min_temp = self.dynatemp_low
|
||||
max_temp = self.dynatemp_high
|
||||
exponent_val = self.dynatemp_exponent
|
||||
|
|
@ -88,6 +103,11 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
|||
|
||||
return scores
|
||||
|
||||
# Regular temperature
|
||||
else:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
|
||||
class MinPLogitsWarper(LogitsWarper):
|
||||
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
|
|
@ -286,7 +306,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||
generation_config.temperature = float(generation_config.temperature)
|
||||
|
||||
temperature = generation_config.temperature
|
||||
if generation_config.dynamic_temperature:
|
||||
if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
|
||||
# Make sure TemperatureLogitsWarper will be created by temporarily
|
||||
# setting temperature to a value != 1.
|
||||
generation_config.temperature = 1.1
|
||||
|
|
@ -294,12 +314,13 @@ def get_logits_warper_patch(self, generation_config):
|
|||
warpers = self._get_logits_warper_old(generation_config)
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = TemperatureLogitsWarperWithDynatemp(
|
||||
warpers[i] = ModifiedTemperatureLogitsWarper(
|
||||
temperature,
|
||||
generation_config.dynamic_temperature,
|
||||
generation_config.dynatemp_low,
|
||||
generation_config.dynatemp_high,
|
||||
generation_config.dynatemp_exponent
|
||||
generation_config.dynatemp_exponent,
|
||||
generation_config.smoothing_factor
|
||||
)
|
||||
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
|
|
@ -328,7 +349,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||
if generation_config.temperature_last:
|
||||
temperature_idx = None
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
|
||||
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
|
||||
temperature_idx = i
|
||||
break
|
||||
|
||||
|
|
@ -352,8 +373,7 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
|
||||
if do_rep_pen_hijack:
|
||||
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
|
||||
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
|
||||
kwargs['generation_config'].repetition_penalty = 1.1 # Set to value > 1 to ensure RepetitionPenaltyLogitsProcessor is created
|
||||
|
||||
result = self._get_logits_processor_old(**kwargs)
|
||||
|
||||
|
|
@ -372,6 +392,7 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
|
||||
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
|
||||
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
|
||||
self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
|
|
|
|||
|
|
@ -285,8 +285,9 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
|||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = state[k]
|
||||
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
|
||||
if k in state:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if state['negative_prompt'] != '':
|
||||
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ def list_interface_input_elements():
|
|||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'smoothing_factor',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ def create_ui(default_preset):
|
|||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
|
||||
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
|
||||
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
|
||||
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue