mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-20 22:13:43 +00:00
Allow more granular KV cache settings (#6561)
This commit is contained in:
parent
c43ee5db11
commit
addad3c63e
8 changed files with 140 additions and 37 deletions
|
|
@ -10,6 +10,35 @@ from modules.llama_cpp_python_hijack import llama_cpp_lib
|
|||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_max_prompt_length
|
||||
|
||||
# Map of quantization-type names to their integer codes.
# NOTE(review): the numeric values presumably mirror llama.cpp's ggml_type
# enum (note the gaps, e.g. 4/5 and 16-19 unused) — confirm against llama.cpp.
llamacpp_quant_mapping = {
    'f32': 0,
    'fp16': 1,
    'q4_0': 2,
    'q4_1': 3,
    'q5_0': 6,
    'q5_1': 7,
    'q8_0': 8,
    'q8_1': 9,
    'q2_k': 10,
    'q3_k': 11,
    'q4_k': 12,
    'q5_k': 13,
    'q6_k': 14,
    'q8_k': 15,
    'iq4_nl': 20,
    'bf16': 30,
}

# Subset of the mapping above that is accepted for the KV cache.
llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'}


def get_llamacpp_cache_type_for_string(quant_type: str):
    """Translate a user-supplied cache-type name into its llama.cpp code.

    The lookup is case-insensitive. Only the types in
    ``llamacpp_valid_cache_types`` are accepted; anything else raises
    ``ValueError``.
    """
    cache_key = quant_type.lower()
    # Guard clause: reject anything outside the allowed cache types early,
    # even if it exists in the wider quant mapping.
    if cache_key not in llamacpp_valid_cache_types:
        raise ValueError(f"Invalid cache type for llama.cpp: {cache_key}. Valid options are: fp16, q8_0, q4_0.")
    return llamacpp_quant_mapping[cache_key]
|
||||
|
||||
|
||||
def ban_eos_logits_processor(eos_token, input_ids, logits):
    # Forbid the EOS token by setting its logit to -inf so it can never be
    # sampled. `input_ids` is accepted but unused here — presumably required
    # by the logits-processor callback signature; verify against the caller.
    # NOTE(review): the diff context appears to truncate this function (it
    # likely ends with `return logits`) — confirm against the full source.
    logits[eos_token] = -float('inf')
|
||||
|
|
@ -75,12 +104,9 @@ class LlamaCppModel:
|
|||
'flash_attn': shared.args.flash_attn
|
||||
}
|
||||
|
||||
if shared.args.cache_4bit:
|
||||
params["type_k"] = 2
|
||||
params["type_v"] = 2
|
||||
elif shared.args.cache_8bit:
|
||||
params["type_k"] = 8
|
||||
params["type_v"] = 8
|
||||
if shared.args.cache_type:
|
||||
params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
|
||||
result.model = Llama(**params)
|
||||
if cache_capacity > 0:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue