Add custom sampler order support (#5443)

This commit is contained in:
oobabooga 2024-02-06 11:20:10 -03:00 committed by GitHub
parent 7301c7618f
commit 8c35fefb3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 205 additions and 113 deletions

View file

@ -182,6 +182,7 @@ def transformers_samplers():
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'sampler_priority',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@ -230,6 +231,7 @@ loaders_samplers = {
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'sampler_priority',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@ -287,6 +289,7 @@ loaders_samplers = {
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'sampler_priority',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',

View file

@ -42,6 +42,7 @@ def default_preset():
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
}

View file

@ -1,4 +1,5 @@
import math
import pprint
import torch
import transformers
@ -6,21 +7,21 @@ from transformers import LogitsWarper, is_torch_xpu_available
from transformers.generation.logits_process import (
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
TemperatureLogitsWarper
LogitsProcessorList
)
from modules import shared
from modules.logging_colors import logger
global_scores = None
class ModifiedTemperatureLogitsWarper(LogitsWarper):
class TemperatureLogitsWarperCustom(LogitsWarper):
'''
Based on the original Transformers temperature logits warper, this
adds support for dynamic temperature and quadratic sampling.
A copy of the original Transformers temperature logits warper.
'''
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@ -32,81 +33,90 @@ class ModifiedTemperatureLogitsWarper(LogitsWarper):
raise ValueError(except_msg)
self.temperature = temperature
self.dynamic_temperature = dynamic_temperature
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores / self.temperature
return scores
class DynamicTemperatureLogitsWarper(LogitsWarper):
'''
Dynamic temperature.
'''
def __init__(self, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
self.dynatemp_low = dynatemp_low
self.dynatemp_high = dynatemp_high
self.dynatemp_exponent = dynatemp_exponent
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
min_temp = self.dynatemp_low
max_temp = self.dynatemp_high
exponent_val = self.dynatemp_exponent
# Convert logits to probabilities
probs = torch.softmax(scores, dim=-1)
# Calculate entropy of the softmax probabilities
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
# Guard against future possible division by zero
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
# Any logits which are not -Infinity will be considered for calculating max entropy.
num_valid_tokens = torch.sum(scores > -float('inf')).item()
# Now, calculate the max entropy by using only the valid tokens' count
max_entropy = math.log(num_valid_tokens)
# Guard against future possible division by zero
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
# Normalize the entropy
normalized_entropy = entropy / max_entropy
# Map the normalized entropy to the desired temperature range using the power function
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
# Apply the dynamically calculated temperature scaling
scores = scores / dyn_temp
# print("----------------------\nTemperature from generation_config:", self.temperature)
# print("min_temp:", min_temp)
# print("max_temp:", max_temp)
# print("Entropy:", entropy.item())
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
# print("Normalized Entropy:", normalized_entropy.item())
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
# print("----------------------")
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)
return scores
class QuadraticSamplingLogitsWarper(LogitsWarper):
'''
Quadratic sampling.
'''
def __init__(self, smoothing_factor: float):
self.smoothing_factor = smoothing_factor
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Compute the maximum logit value
max_logit = scores.max()
# Quadratic sampling
if self.smoothing_factor > 0:
# Apply the quadratic transformation
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
# Compute the maximum logit value
max_logit = scores.max()
# No need to print the top 5 logits since this is not required
# print("Original top 5 logits: ", torch.topk(scores, 5))
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
# Apply the quadratic transformation
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
# No need to print the top 5 logits since this is not required
# print("Original top 5 logits: ", torch.topk(scores, 5))
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
return transformed_logits
# Dynamic temperature
elif self.dynamic_temperature:
min_temp = self.dynatemp_low
max_temp = self.dynatemp_high
exponent_val = self.dynatemp_exponent
# Convert logits to probabilities
probs = torch.softmax(scores, dim=-1)
# Calculate entropy of the softmax probabilities
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
# Guard against future possible division by zero
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
# Any logits which are not -Infinity will be considered for calculating max entropy.
num_valid_tokens = torch.sum(scores > -float('inf')).item()
# Now, calculate the max entropy by using only the valid tokens' count
max_entropy = math.log(num_valid_tokens)
# Guard against future possible division by zero
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
# Normalize the entropy
normalized_entropy = entropy / max_entropy
# Map the normalized entropy to the desired temperature range using the power function
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
# Apply the dynamically calculated temperature scaling
scores = scores / dyn_temp
# print("----------------------\nTemperature from generation_config:", self.temperature)
# print("min_temp:", min_temp)
# print("max_temp:", max_temp)
# print("Entropy:", entropy.item())
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
# print("Normalized Entropy:", normalized_entropy.item())
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
# print("----------------------")
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)
return scores
# Regular temperature
else:
scores = scores / self.temperature
return scores
return transformed_logits
class MinPLogitsWarper(LogitsWarper):
@ -209,6 +219,7 @@ class MirostatLogitsWarper(LogitsWarper):
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if mirostat_mode not in [2]:
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
self.mirostat_mode = mirostat_mode
self.mirostat_eta = mirostat_eta
self.mirostat_tau = mirostat_tau
@ -301,44 +312,74 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
def get_logits_warper_patch(self, generation_config):
# Make sure that temperature is float and not int
# Parameter sanitization
if isinstance(generation_config.temperature, int):
generation_config.temperature = float(generation_config.temperature)
temperature = generation_config.temperature
if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1
generation_config.temperature = float(generation_config.temperature) # Must be float
# Get the original warpers
warpers = self._get_logits_warper_old(generation_config)
# Replace temperature with our modified class.
# Currently, it behaves identically to the original.
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = ModifiedTemperatureLogitsWarper(
temperature,
generation_config.dynamic_temperature,
generation_config.dynatemp_low,
generation_config.dynatemp_high,
generation_config.dynatemp_exponent,
generation_config.smoothing_factor
warpers[i] = TemperatureLogitsWarperCustom(
generation_config.temperature,
)
# Add custom warpers
warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
warpers_to_add.append(
TailFreeLogitsWarper(
tfs=generation_config.tfs,
min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
warpers_to_add.append(
TopALogitsWarper(
top_a=generation_config.top_a,
min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
warpers_to_add.append(
MinPLogitsWarper(
min_p=generation_config.min_p,
min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.dynamic_temperature:
warpers_to_add.append(
DynamicTemperatureLogitsWarper(
dynatemp_low=generation_config.dynatemp_low,
dynatemp_high=generation_config.dynatemp_high,
dynatemp_exponent=generation_config.dynatemp_exponent,
)
)
if generation_config.smoothing_factor > 0:
warpers_to_add.append(
QuadraticSamplingLogitsWarper(
smoothing_factor=generation_config.smoothing_factor
)
)
if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
# We need to disable samplers other than temperature
for warper in warpers:
if not isinstance(warper, TemperatureLogitsWarper):
warpers.remove(warper)
else:
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
warpers_to_add.append(
MirostatLogitsWarper(
mirostat_mode=generation_config.mirostat_mode,
mirostat_eta=generation_config.mirostat_eta,
mirostat_tau=generation_config.mirostat_tau,
min_tokens_to_keep=min_tokens_to_keep
)
)
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
normalize = warpers.pop(-1)
@ -346,23 +387,57 @@ def get_logits_warper_patch(self, generation_config):
normalize = None
warpers += warpers_to_add
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
temperature_idx = i
break
if temperature_idx is not None:
warpers.append(warpers.pop(temperature_idx))
# Sort the samplers.
sampler_priority = generation_config.sampler_priority
# Handle temperature_last
if generation_config.temperature_last:
for param_name in ['temperature', 'dynamic_temperature', 'quadratic_sampling']:
if param_name in sampler_priority:
if param_name in sampler_priority:
index = sampler_priority.index(param_name)
sampler_priority.append(sampler_priority.pop(index))
else:
sampler_priority.append(param_name)
class_name_to_nickname = {
'DynamicTemperatureLogitsWarper': 'dynamic_temperature',
'EpsilonLogitsWarper': 'epsilon_cutoff',
'EtaLogitsWarper': 'eta_cutoff',
'MinPLogitsWarper': 'min_p',
'MirostatLogitsWarper': 'mirostat',
'QuadraticSamplingLogitsWarper': 'quadratic_sampling',
'TailFreeLogitsWarper': 'tfs',
'TemperatureLogitsWarperCustom': 'temperature',
'TopALogitsWarper': 'top_a',
'TopKLogitsWarper': 'top_k',
'TopPLogitsWarper': 'top_p',
'TypicalLogitsWarper': 'typical_p'
}
def custom_sort_key(obj):
class_name = obj.__class__.__name__
# Return a large value if class name is not mapped or if the mapped nickname is not in priority
if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority:
return float('inf')
# Return the index of the nickname in the priority list for sorting
return sampler_priority.index(class_name_to_nickname[class_name])
# Sort the list using the custom key function
warpers = sorted(warpers, key=custom_sort_key)
if normalize is not None:
warpers.append(normalize)
warpers.append(SpyLogitsWarper())
warpers = LogitsProcessorList(warpers)
# for i in range(len(warpers)):
# print(warpers[i].__class__.__name__)
if shared.args.verbose:
logger.info("WARPERS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
return warpers
@ -402,6 +477,7 @@ def generation_config_init_patch(self, **kwargs):
self.presence_penalty = kwargs.pop("presence_penalty", 0)
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
self.temperature_last = kwargs.pop("temperature_last", False)
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
def hijack_samplers():

View file

@ -50,6 +50,7 @@ settings = {
'prompt_lookup_num_tokens': 0,
'custom_stopping_strings': '',
'custom_token_bans': '',
'sampler_priority': 'temperature,top_k,top_p,typical_p,epsilon_cutoff,eta_cutoff,tfs,top_a,min_p,dynamic_temperature,quadratic_sampling,mirostat',
'auto_max_new_tokens': False,
'ban_eos_token': False,
'add_bos_token': True,

View file

@ -291,6 +291,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if k in state:
generate_params[k] = state[k]
if isinstance(state['sampler_priority'], list):
generate_params['sampler_priority'] = state['sampler_priority']
elif isinstance(state['sampler_priority'], str):
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
if state['negative_prompt'] != '':
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])

View file

@ -149,6 +149,7 @@ def list_interface_input_elements():
'add_bos_token',
'ban_eos_token',
'custom_token_bans',
'sampler_priority',
'truncation_length',
'custom_stopping_strings',
'skip_special_tokens',

View file

@ -49,12 +49,12 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
shared.gradio['dynatemp_exponent'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_exponent'], step=0.01, label='dynatemp_exponent', visible=generate_params['dynamic_temperature'])
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
with gr.Accordion('Other parameters', open=False):
@ -85,6 +85,9 @@ def create_ui(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
with gr.Blocks():
shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.')
with gr.Row() as shared.gradio['grammar_file_row']:
shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Load grammar from file (.gbnf)', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button', interactive=not mu)