mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-06 05:33:50 +01:00
Add adaptive-p sampler and n-gram speculative decoding support
This commit is contained in:
parent
f010aa1612
commit
65de4c30c8
|
|
@ -22,6 +22,8 @@ class GenerationOptions(BaseModel):
|
|||
tfs: float = 1
|
||||
top_a: float = 0
|
||||
top_n_sigma: float = 0
|
||||
adaptive_target: float = 0
|
||||
adaptive_decay: float = 0.9
|
||||
dry_multiplier: float = 0
|
||||
dry_allowed_length: int = 2
|
||||
dry_base: float = 1.75
|
||||
|
|
@ -48,7 +50,7 @@ class GenerationOptions(BaseModel):
|
|||
static_cache: bool = False
|
||||
truncation_length: int = 0
|
||||
seed: int = -1
|
||||
sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
|
||||
sampler_priority: List[str] | str | None = Field(default=['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'top_n_sigma', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'], description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
|
||||
custom_token_bans: str = ""
|
||||
negative_prompt: str = ''
|
||||
dry_sequence_breakers: str = '"\\n", ":", "\\"", "*"'
|
||||
|
|
|
|||
|
|
@ -75,6 +75,8 @@ class LlamaServer:
|
|||
"top_p": state["top_p"],
|
||||
"min_p": state["min_p"],
|
||||
"top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1,
|
||||
"adaptive_target": state["adaptive_target"] if state["adaptive_target"] > 0 else -1,
|
||||
"adaptive_decay": state["adaptive_decay"],
|
||||
"typical_p": state["typical_p"],
|
||||
"repeat_penalty": state["repetition_penalty"],
|
||||
"repeat_last_n": state["repetition_penalty_range"],
|
||||
|
|
@ -123,6 +125,12 @@ class LlamaServer:
|
|||
filtered_samplers.remove("temperature")
|
||||
filtered_samplers.append("temperature")
|
||||
|
||||
# adaptive-p replaces the default dist sampler; llama.cpp always
|
||||
# places it at the end of the chain regardless of position, so we
|
||||
# activate it based on the parameter value rather than sampler order.
|
||||
if state.get("adaptive_target", 0) > 0:
|
||||
filtered_samplers.append("adaptive-p")
|
||||
|
||||
payload["samplers"] = filtered_samplers
|
||||
|
||||
if state['custom_token_bans']:
|
||||
|
|
@ -391,6 +399,16 @@ class LlamaServer:
|
|||
cmd += ["--device-draft", shared.args.device_draft]
|
||||
if shared.args.ctx_size_draft > 0:
|
||||
cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)]
|
||||
if shared.args.spec_type != 'none':
|
||||
cmd += ["--spec-type", shared.args.spec_type]
|
||||
if shared.args.draft_max > 0:
|
||||
cmd += ["--draft-max", str(shared.args.draft_max)]
|
||||
if shared.args.spec_ngram_size_n != 12:
|
||||
cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)]
|
||||
if shared.args.spec_ngram_size_m != 48:
|
||||
cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)]
|
||||
if shared.args.spec_ngram_min_hits != 1:
|
||||
cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)]
|
||||
if shared.args.streaming_llm:
|
||||
cmd += ["--cache-reuse", "1"]
|
||||
cmd += ["--swa-full"]
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ loaders_and_params = OrderedDict({
|
|||
'gpu_layers_draft',
|
||||
'device_draft',
|
||||
'ctx_size_draft',
|
||||
'spec_type',
|
||||
'spec_ngram_size_n',
|
||||
'spec_ngram_size_m',
|
||||
'spec_ngram_min_hits',
|
||||
'speculative_decoding_accordion',
|
||||
'mmproj',
|
||||
'mmproj_accordion',
|
||||
|
|
@ -128,6 +132,8 @@ def transformers_samplers():
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
@ -183,6 +189,8 @@ loaders_samplers = {
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
@ -231,6 +239,8 @@ loaders_samplers = {
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
@ -327,6 +337,8 @@ loaders_samplers = {
|
|||
'xtc_threshold',
|
||||
'xtc_probability',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ def default_preset():
|
|||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'top_n_sigma': 0,
|
||||
'adaptive_target': 0,
|
||||
'adaptive_decay': 0.9,
|
||||
'dry_multiplier': 0,
|
||||
'dry_allowed_length': 2,
|
||||
'dry_base': 1.75,
|
||||
|
|
@ -45,7 +47,7 @@ def default_preset():
|
|||
'do_sample': True,
|
||||
'dynamic_temperature': False,
|
||||
'temperature_last': False,
|
||||
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
|
||||
'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram',
|
||||
'dry_sequence_breakers': '"\\n", ":", "\\"", "*"',
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -235,6 +235,73 @@ class TopNSigmaLogitsWarper(LogitsProcessor):
|
|||
return scores
|
||||
|
||||
|
||||
class AdaptivePLogitsWarper(LogitsProcessor):
    '''
    Adaptive-p sampling. A stateful sampler that favors tokens near a target
    probability, using an EMA-based control loop to adapt over time.

    Matches the llama.cpp implementation from PR #17927.

    Unlike a filtering warper, this one draws the next token itself and then
    masks every other token with `filter_value`, so the final sampling step
    can only pick the token chosen here.

    Only the first row of `scores` is used for sampling and the resulting
    mask is broadcast to every row, so the warper is intended for a batch
    size of 1.
    '''

    # Shape constants of the probability -> logit transform.
    DISTRIBUTION_WIDTH = 0.3
    PEAK_LOGIT_VALUE = 5.0
    SHARPNESS = 10.0
    INV_WIDTH = 1.0 / DISTRIBUTION_WIDTH

    def __init__(self, adaptive_target, adaptive_decay, filter_value=-float("Inf"), min_tokens_to_keep=1):
        '''
        adaptive_target: probability the sampler steers the selected tokens toward.
        adaptive_decay: EMA decay rate; clamped to [0, 0.99].
        filter_value: value written over every non-selected token.
        min_tokens_to_keep: stored for signature parity with the other
            warpers; this sampler always keeps exactly one token.
        '''
        self.target = adaptive_target

        # Clamp on both sides. The upper bound keeps 1 / (1 - decay) finite;
        # the lower bound prevents a negative decay from making the EMA
        # recurrence below oscillate or diverge.
        self.decay = max(0.0, min(adaptive_decay, 0.99))
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

        # Initialize EMA at equilibrium (as if target was already achieved).
        # The clamp above guarantees 1 - decay >= 0.01, so no zero division.
        equilibrium_weight = 1.0 / (1.0 - self.decay)
        self.weighted_sum = self.target * equilibrium_weight
        self.total_weight = equilibrium_weight

    def __call__(self, input_ids, scores):
        '''
        Sample one token from the adaptive-p distribution built from
        `scores[0]`, update the EMA state with that token's original
        probability, and return `scores` with every other token masked.
        '''
        logits = scores[0]

        # Compute original probabilities (before transform)
        probs = torch.softmax(logits, dim=-1)

        # Compute adapted target using proportional control on the EMA:
        # if recent picks ran above the target, aim lower, and vice versa.
        if self.total_weight > 0:
            ema_avg = self.weighted_sum / self.total_weight
        else:
            ema_avg = self.target

        adapted_target = max(0.0, min(1.0, 2.0 * self.target - ema_avg))

        # Adaptive probability transform:
        # quadratic near target for fine differentiation, transitioning
        # to linear decay in the tails for proper suppression after softmax
        dist = torch.abs((probs - adapted_target) * self.INV_WIDTH)
        new_logits = self.PEAK_LOGIT_VALUE - self.SHARPNESS * dist * dist / (1.0 + dist)

        # Preserve already-masked tokens (-inf logits from prior samplers)
        new_logits = torch.where(torch.isfinite(logits), new_logits, logits)

        # Softmax and sample from the transformed distribution
        new_probs = torch.softmax(new_logits, dim=-1)
        selected = torch.multinomial(new_probs, num_samples=1, replacement=True)

        # Update EMA with the original probability of the selected token
        original_prob = probs[selected[0]].item()
        self.weighted_sum = original_prob + self.decay * self.weighted_sum
        self.total_weight = 1.0 + self.decay * self.total_weight

        # Mask all tokens except the selected one. The (1, vocab) mask is
        # broadcast over the batch dimension by masked_fill.
        indices_to_remove = torch.ones_like(logits, dtype=torch.bool)
        indices_to_remove[selected[0]] = False
        scores = scores.masked_fill(indices_to_remove.unsqueeze(0), self.filter_value)
        return scores
|
||||
|
||||
|
||||
# Exclude Top Choices (XTC)
|
||||
class XTCLogitsWarper(LogitsProcessor):
|
||||
def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
|
||||
|
|
@ -575,6 +642,15 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
)
|
||||
)
|
||||
|
||||
if generation_config.adaptive_target is not None and generation_config.adaptive_target > 0.0:
|
||||
warpers_to_add.append(
|
||||
AdaptivePLogitsWarper(
|
||||
adaptive_target=generation_config.adaptive_target,
|
||||
adaptive_decay=generation_config.adaptive_decay,
|
||||
min_tokens_to_keep=min_tokens_to_keep
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
|
||||
warpers_to_add.append(
|
||||
XTCLogitsWarper(
|
||||
|
|
@ -640,6 +716,7 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
'TemperatureLogitsWarperCustom': 'temperature',
|
||||
'TopALogitsWarper': 'top_a',
|
||||
'TopNSigmaLogitsWarper': 'top_n_sigma',
|
||||
'AdaptivePLogitsWarper': 'adaptive_p',
|
||||
'TopKLogitsWarper': 'top_k',
|
||||
'TopPLogitsWarper': 'top_p',
|
||||
'TypicalLogitsWarper': 'typical_p',
|
||||
|
|
@ -688,6 +765,8 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0)
|
||||
self.adaptive_target = kwargs.pop("adaptive_target", 0.0)
|
||||
self.adaptive_decay = kwargs.pop("adaptive_decay", 0.9)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||
|
|
@ -701,7 +780,7 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
|
||||
self.xtc_probability = kwargs.pop("xtc_probability", 0)
|
||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'])
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
|
|
|
|||
|
|
@ -81,6 +81,10 @@ group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to
|
|||
group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.')
|
||||
group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
||||
group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
||||
group.add_argument('--spec-type', type=str, default='none', choices=['none', 'ngram-cache', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-mod'], help='Speculative decoding type for draftless speculation.')
|
||||
group.add_argument('--spec-ngram-size-n', type=int, default=12, help='N-gram lookup size for ngram speculative decoding.')
|
||||
group.add_argument('--spec-ngram-size-m', type=int, default=48, help='Draft n-gram size for ngram speculative decoding.')
|
||||
group.add_argument('--spec-ngram-min-hits', type=int, default=1, help='Minimum n-gram hits for ngram-map speculative decoding.')
|
||||
|
||||
# llama.cpp
|
||||
group = parser.add_argument_group('llama.cpp')
|
||||
|
|
@ -269,6 +273,8 @@ settings = {
|
|||
'tfs': neutral_samplers['tfs'],
|
||||
'top_a': neutral_samplers['top_a'],
|
||||
'top_n_sigma': neutral_samplers['top_n_sigma'],
|
||||
'adaptive_target': neutral_samplers['adaptive_target'],
|
||||
'adaptive_decay': neutral_samplers['adaptive_decay'],
|
||||
|
||||
# Generation parameters - Repetition suppression
|
||||
'dry_multiplier': neutral_samplers['dry_multiplier'],
|
||||
|
|
|
|||
|
|
@ -317,6 +317,8 @@ def generate_reply_HF(question, original_question, state, stopping_strings=None,
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
|
|||
|
|
@ -168,6 +168,10 @@ def list_model_elements():
|
|||
'gpu_layers_draft',
|
||||
'device_draft',
|
||||
'ctx_size_draft',
|
||||
'spec_type',
|
||||
'spec_ngram_size_n',
|
||||
'spec_ngram_size_m',
|
||||
'spec_ngram_min_hits',
|
||||
'mmproj',
|
||||
]
|
||||
|
||||
|
|
@ -193,6 +197,8 @@ def list_interface_input_elements():
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
@ -488,6 +494,8 @@ def setup_auto_save():
|
|||
'tfs',
|
||||
'top_a',
|
||||
'top_n_sigma',
|
||||
'adaptive_target',
|
||||
'adaptive_decay',
|
||||
'dry_multiplier',
|
||||
'dry_allowed_length',
|
||||
'dry_base',
|
||||
|
|
|
|||
|
|
@ -76,6 +76,10 @@ def create_ui():
|
|||
shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding. Recommended value: 4.')
|
||||
shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1')
|
||||
shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.')
|
||||
shared.gradio['spec_type'] = gr.Dropdown(label="spec-type", choices=['none', 'ngram-cache', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-mod'], value=shared.args.spec_type, info='Draftless speculative decoding type. Uses n-gram matching from context.')
|
||||
shared.gradio['spec_ngram_size_n'] = gr.Number(label="spec-ngram-size-n", precision=0, step=1, value=shared.args.spec_ngram_size_n, info='N-gram lookup size for speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||
shared.gradio['spec_ngram_size_m'] = gr.Number(label="spec-ngram-size-m", precision=0, step=1, value=shared.args.spec_ngram_size_m, info='Draft n-gram size for speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||
shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none')
|
||||
|
||||
gr.Markdown("## Other options")
|
||||
with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'):
|
||||
|
|
@ -179,6 +183,13 @@ def create_event_handlers():
|
|||
if not shared.args.portable:
|
||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
|
||||
|
||||
shared.gradio['spec_type'].change(
|
||||
lambda x: [gr.update(visible=x != 'none')] * 3,
|
||||
gradio('spec_type'),
|
||||
gradio('spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits'),
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ def create_ui():
|
|||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=shared.settings['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=shared.settings['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=shared.settings['mirostat_eta'], label='mirostat_eta')
|
||||
shared.gradio['adaptive_target'] = gr.Slider(0.0, 1.0, value=shared.settings['adaptive_target'], step=0.01, label='adaptive_target', info='Target probability for adaptive-p sampling. Tokens near this probability are favored. 0 disables.')
|
||||
shared.gradio['adaptive_decay'] = gr.Slider(0.0, 0.99, value=shared.settings['adaptive_decay'], step=0.01, label='adaptive_decay', info='EMA decay rate for adaptive-p. Controls history window (~1/(1-decay) tokens). Default: 0.9.')
|
||||
|
||||
gr.Markdown('## Other options')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample')
|
||||
|
|
|
|||
Loading…
Reference in a new issue