Allow more granular KV cache settings (#6561)

Author: Diner Burger (committed by GitHub)
Date:   2024-12-17 15:43:48 -05:00
Commit: addad3c63e
Parent: c43ee5db11
8 changed files with 140 additions and 37 deletions


@@ -2,17 +2,19 @@ import traceback
 from pathlib import Path

 import torch
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_Q6,
+    ExLlamaV2Cache_Q8,
     ExLlamaV2Cache_TP,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
 from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
@@ -57,12 +59,22 @@ class Exllamav2Model:
         model.load(split)

         # Determine the correct cache type
-        if shared.args.cache_8bit:
+        kv_cache_type = 'fp16'
+        if shared.args.cache_type:
+            kv_cache_type = shared.args.cache_type.lower()
+
+        if kv_cache_type == 'fp16':
+            cache_type = ExLlamaV2Cache
+        elif kv_cache_type == 'fp8':
             cache_type = ExLlamaV2Cache_8bit
-        elif shared.args.cache_4bit:
+        elif kv_cache_type == 'q8':
+            cache_type = ExLlamaV2Cache_Q8
+        elif kv_cache_type == 'q6':
+            cache_type = ExLlamaV2Cache_Q6
+        elif kv_cache_type == 'q4':
             cache_type = ExLlamaV2Cache_Q4
         else:
-            cache_type = ExLlamaV2Cache
+            raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

         # Use TP if specified
         if shared.args.enable_tp:
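The new ExLlamaV2 branch above is a plain name-to-class dispatch with fp16 as the default. For illustration only (the dict and the pick_exllamav2_cache helper below are not part of the commit), the same mapping can be written as a lookup table:

    from exllamav2 import (
        ExLlamaV2Cache,
        ExLlamaV2Cache_8bit,
        ExLlamaV2Cache_Q4,
        ExLlamaV2Cache_Q6,
        ExLlamaV2Cache_Q8,
    )

    # Same name-to-class mapping as the if/elif chain introduced above.
    EXLLAMAV2_CACHE_CLASSES = {
        'fp16': ExLlamaV2Cache,      # full-precision KV cache (default)
        'fp8': ExLlamaV2Cache_8bit,  # 8-bit float KV cache
        'q8': ExLlamaV2Cache_Q8,     # quantized KV caches
        'q6': ExLlamaV2Cache_Q6,
        'q4': ExLlamaV2Cache_Q4,
    }


    def pick_exllamav2_cache(cache_type_arg=None):
        # Mirrors the commit's behavior: default to fp16, lowercase the input,
        # and reject anything outside the supported set.
        kv_cache_type = (cache_type_arg or 'fp16').lower()
        if kv_cache_type not in EXLLAMAV2_CACHE_CLASSES:
            raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

        return EXLLAMAV2_CACHE_CLASSES[kv_cache_type]

The commit keeps an explicit if/elif chain instead, which matches the style of the surrounding loader code.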


@@ -4,18 +4,20 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Union

 import torch
+from torch.nn import CrossEntropyLoss
+from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast

 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_Q6,
+    ExLlamaV2Cache_Q8,
     ExLlamaV2Cache_TP,
     ExLlamaV2Config
 )
-from torch.nn import CrossEntropyLoss
-from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast

 from modules import shared
 from modules.logging_colors import logger
@@ -45,12 +47,22 @@ class Exllamav2HF(PreTrainedModel):
         self.ex_model.load(split)

         # Determine the correct cache type
-        if shared.args.cache_8bit:
+        kv_cache_type = 'fp16'
+        if shared.args.cache_type:
+            kv_cache_type = shared.args.cache_type.lower()
+
+        if kv_cache_type == 'fp16':
+            cache_type = ExLlamaV2Cache
+        elif kv_cache_type == 'fp8':
             cache_type = ExLlamaV2Cache_8bit
-        elif shared.args.cache_4bit:
+        elif kv_cache_type == 'q8':
+            cache_type = ExLlamaV2Cache_Q8
+        elif kv_cache_type == 'q6':
+            cache_type = ExLlamaV2Cache_Q6
+        elif kv_cache_type == 'q4':
             cache_type = ExLlamaV2Cache_Q4
         else:
-            cache_type = ExLlamaV2Cache
+            raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

         # Use TP if specified
         if shared.args.enable_tp:


@@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast

 from modules import shared
 from modules.llama_cpp_python_hijack import llama_cpp_lib
+from modules.llamacpp_model import get_llamacpp_cache_type_for_string
 from modules.logging_colors import logger
@@ -196,12 +197,9 @@ class LlamacppHF(PreTrainedModel):
             'flash_attn': shared.args.flash_attn
         }

-        if shared.args.cache_4bit:
-            params["type_k"] = 2
-            params["type_v"] = 2
-        elif shared.args.cache_8bit:
-            params["type_k"] = 8
-            params["type_v"] = 8
+        if shared.args.cache_type:
+            params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
+            params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)

         Llama = llama_cpp_lib().Llama
         model = Llama(**params)
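type_k and type_v are llama-cpp-python's Llama() parameters for the key- and value-cache data types; they take integer ggml type ids, which is why the helper used here returns numbers. As a rough, hypothetical illustration of what the params end up expressing (the model path and context size are placeholders, not from the commit):

    from llama_cpp import Llama

    # Hypothetical call: store both halves of the KV cache as q8_0 (ggml type id 8).
    llm = Llama(
        model_path="/path/to/model.gguf",  # placeholder
        n_ctx=4096,                        # placeholder
        flash_attn=True,                   # llama.cpp generally requires flash attention for a quantized V cache
        type_k=8,                          # key cache: q8_0
        type_v=8,                          # value cache: q8_0
    )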


@@ -10,6 +10,35 @@ from modules.llama_cpp_python_hijack import llama_cpp_lib
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length

+llamacpp_quant_mapping = {
+    'f32': 0,
+    'fp16': 1,
+    'q4_0': 2,
+    'q4_1': 3,
+    'q5_0': 6,
+    'q5_1': 7,
+    'q8_0': 8,
+    'q8_1': 9,
+    'q2_k': 10,
+    'q3_k': 11,
+    'q4_k': 12,
+    'q5_k': 13,
+    'q6_k': 14,
+    'q8_k': 15,
+    'iq4_nl': 20,
+    'bf16': 30,
+}
+
+llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'}
+
+
+def get_llamacpp_cache_type_for_string(quant_type: str):
+    quant_type = quant_type.lower()
+    if quant_type in llamacpp_valid_cache_types:
+        return llamacpp_quant_mapping[quant_type]
+    else:
+        raise ValueError(f"Invalid cache type for llama.cpp: {quant_type}. Valid options are: fp16, q8_0, q4_0.")
+
+
 def ban_eos_logits_processor(eos_token, input_ids, logits):
     logits[eos_token] = -float('inf')
@@ -75,12 +104,9 @@ class LlamaCppModel:
             'flash_attn': shared.args.flash_attn
         }

-        if shared.args.cache_4bit:
-            params["type_k"] = 2
-            params["type_v"] = 2
-        elif shared.args.cache_8bit:
-            params["type_k"] = 8
-            params["type_v"] = 8
+        if shared.args.cache_type:
+            params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
+            params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)

         result.model = Llama(**params)

         if cache_capacity > 0:
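The integers in llamacpp_quant_mapping mirror llama.cpp's ggml type ids (f16 = 1, q4_0 = 2, q8_0 = 8, and so on), and only fp16, q8_0 and q4_0 are accepted as KV cache types. A small illustrative check of the helper defined above (assuming it is imported from modules.llamacpp_model):

    from modules.llamacpp_model import get_llamacpp_cache_type_for_string

    assert get_llamacpp_cache_type_for_string('fp16') == 1
    assert get_llamacpp_cache_type_for_string('Q8_0') == 8  # lookup is case-insensitive
    assert get_llamacpp_cache_type_for_string('q4_0') == 2

    # Types that exist in the mapping but are not valid cache types are rejected:
    try:
        get_llamacpp_cache_type_for_string('q6_k')
    except ValueError as e:
        print(e)  # Invalid cache type for llama.cpp: q6_k. Valid options are: fp16, q8_0, q4_0.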


@@ -31,8 +31,7 @@ loaders_and_params = OrderedDict({
     'llama.cpp': [
         'n_ctx',
         'n_gpu_layers',
-        'cache_8bit',
-        'cache_4bit',
+        'cache_type',
         'tensor_split',
         'n_batch',
         'threads',
@@ -54,8 +53,7 @@ loaders_and_params = OrderedDict({
     'llamacpp_HF': [
         'n_ctx',
         'n_gpu_layers',
-        'cache_8bit',
-        'cache_4bit',
+        'cache_type',
         'tensor_split',
         'n_batch',
         'threads',
@@ -87,8 +85,7 @@ loaders_and_params = OrderedDict({
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'cache_8bit',
-        'cache_4bit',
+        'cache_type',
         'autosplit',
         'enable_tp',
         'alpha_value',
@@ -103,8 +100,7 @@ loaders_and_params = OrderedDict({
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'cache_8bit',
-        'cache_4bit',
+        'cache_type',
         'autosplit',
         'enable_tp',
         'alpha_value',


@@ -142,8 +142,6 @@ group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Creat
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
 group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
-group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
-group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
 group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')

@@ -166,6 +164,10 @@ group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='B
 group = parser.add_argument_group('TensorRT-LLM')
 group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')

+# Cache
+group = parser.add_argument_group('Cache')
+group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+
 # DeepSpeed
 group = parser.add_argument_group('DeepSpeed')
 group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
@@ -213,6 +215,8 @@ group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED')
 group.add_argument('--checkpoint', type=str, help='DEPRECATED')
 group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED')
 group.add_argument('--no_inject_fused_attention', action='store_true', help='DEPRECATED')
+group.add_argument('--cache_4bit', action='store_true', help='DEPRECATED')
+group.add_argument('--cache_8bit', action='store_true', help='DEPRECATED')
 group.add_argument('--chat-buttons', action='store_true', help='DEPRECATED')

 args = parser.parse_args()
@@ -270,6 +274,59 @@ def fix_loader_name(name):
         return 'TensorRT-LLM'


+def transform_legacy_kv_cache_options(opts):
+    # Handle both argparse.Namespace and dict here
+    def get(key):
+        return opts.get(key) if isinstance(opts, dict) else getattr(opts, key, None)
+
+    def set(key, value):
+        if isinstance(opts, dict):
+            opts[key] = value
+        else:
+            setattr(opts, key, value)
+
+    def del_key(key, fallback_set):
+        # only remove from user dict, can't delete from argparse.Namespace
+        if type(opts) is dict:
+            if key in opts:
+                del opts[key]
+        else:
+            setattr(opts, key, fallback_set)
+
+    # Retrieve values
+    loader = get('loader')
+    cache_type = get('cache_type')
+    cache_8bit = get('cache_8bit')
+    cache_4bit = get('cache_4bit')
+
+    # Determine cache type based on loader or legacy flags
+    if not cache_type:
+        if not loader:
+            # Legacy behavior: prefer 8-bit over 4-bit to minimize breakage
+            if cache_8bit:
+                set('cache_type', 'fp8')
+            elif cache_4bit:
+                set('cache_type', 'q4')
+        elif loader.lower() in ['exllamav2', 'exllamav2_hf']:
+            # ExLlamaV2 loader-specific cache type
+            if cache_8bit:
+                set('cache_type', 'fp8')
+            elif cache_4bit:
+                set('cache_type', 'q4')
+        elif loader.lower() in ['llama.cpp', 'llamacpp_hf']:
+            # Llama.cpp loader-specific cache type
+            if cache_4bit:
+                set('cache_type', 'q4_0')
+            elif cache_8bit:
+                set('cache_type', 'q8_0')
+
+    # Clean up legacy keys
+    del_key('cache_4bit', False)
+    del_key('cache_8bit', False)
+
+    return opts
+
+
 def add_extension(name, last=False):
     if args.extensions is None:
         args.extensions = [name]
@@ -298,10 +355,14 @@ def load_user_config():
     else:
         user_config = {}

+    for model_name in user_config:
+        user_config[model_name] = transform_legacy_kv_cache_options(user_config[model_name])
+
     return user_config


 args.loader = fix_loader_name(args.loader)
+args = transform_legacy_kv_cache_options(args)

 # Activate the multimodal extension
 if args.multimodal_pipeline is not None:
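transform_legacy_kv_cache_options is applied both to the parsed command-line namespace and to each per-model entry of the user config, so existing cache_8bit/cache_4bit settings keep working after the upgrade. An illustrative check of the translation rules above (made-up values; assumes the function is importable from modules.shared):

    from argparse import Namespace

    from modules.shared import transform_legacy_kv_cache_options

    # Dict form, as used for per-model user config entries.
    opts = transform_legacy_kv_cache_options({'loader': 'llama.cpp', 'cache_4bit': True})
    assert opts == {'loader': 'llama.cpp', 'cache_type': 'q4_0'}

    # Namespace form, as used for the parsed CLI arguments.
    ns = Namespace(loader='ExLlamav2_HF', cache_type=None, cache_8bit=True, cache_4bit=False)
    ns = transform_legacy_kv_cache_options(ns)
    assert ns.cache_type == 'fp8'
    assert ns.cache_8bit is False and ns.cache_4bit is False  # legacy flags are reset, not deleted

Going forward, the equivalent of the old --cache_4bit flag is spelled --cache_type q4_0 for llama.cpp and --cache_type q4 for ExLlamaV2.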


@@ -130,8 +130,7 @@ def list_model_elements():
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'cache_8bit',
-        'cache_4bit',
+        'cache_type',
         'autosplit',
         'enable_tp',
         'threads',


@@ -118,8 +118,7 @@ def create_ui():
             shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
             shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
             shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.')
-            shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
-            shared.gradio['cache_4bit'] = gr.Checkbox(label="cache_4bit", value=shared.args.cache_4bit, info='Use Q4 cache to save VRAM.')
+            shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
             shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
             shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
             shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')