Allow more granular KV cache settings (#6561)

This commit is contained in:
Diner Burger 2024-12-17 15:43:48 -05:00 committed by GitHub
parent c43ee5db11
commit addad3c63e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 140 additions and 37 deletions

View file

@ -142,8 +142,6 @@ group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Creat
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
@ -166,6 +164,10 @@ group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='B
group = parser.add_argument_group('TensorRT-LLM')
group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
# Cache
group = parser.add_argument_group('Cache')
group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
# DeepSpeed
group = parser.add_argument_group('DeepSpeed')
group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
@ -213,6 +215,8 @@ group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED')
group.add_argument('--checkpoint', type=str, help='DEPRECATED')
group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED')
group.add_argument('--no_inject_fused_attention', action='store_true', help='DEPRECATED')
group.add_argument('--cache_4bit', action='store_true', help='DEPRECATED')
group.add_argument('--cache_8bit', action='store_true', help='DEPRECATED')
group.add_argument('--chat-buttons', action='store_true', help='DEPRECATED')
args = parser.parse_args()
@ -270,6 +274,59 @@ def fix_loader_name(name):
return 'TensorRT-LLM'
def transform_legacy_kv_cache_options(opts):
# Handle both argparse.Namespace and dict here
def get(key):
return opts.get(key) if isinstance(opts, dict) else getattr(opts, key, None)
def set(key, value):
if isinstance(opts, dict):
opts[key] = value
else:
setattr(opts, key, value)
def del_key(key, fallback_set):
# only remove from user dict, can't delete from argparse.Namespace
if type(opts) is dict:
if key in opts:
del opts[key]
else:
setattr(opts, key, fallback_set)
# Retrieve values
loader = get('loader')
cache_type = get('cache_type')
cache_8bit = get('cache_8bit')
cache_4bit = get('cache_4bit')
# Determine cache type based on loader or legacy flags
if not cache_type:
if not loader:
# Legacy behavior: prefer 8-bit over 4-bit to minimize breakage
if cache_8bit:
set('cache_type', 'fp8')
elif cache_4bit:
set('cache_type', 'q4')
elif loader.lower() in ['exllamav2', 'exllamav2_hf']:
# ExLlamaV2 loader-specific cache type
if cache_8bit:
set('cache_type', 'fp8')
elif cache_4bit:
set('cache_type', 'q4')
elif loader.lower() in ['llama.cpp', 'llamacpp_hf']:
# Llama.cpp loader-specific cache type
if cache_4bit:
set('cache_type', 'q4_0')
elif cache_8bit:
set('cache_type', 'q8_0')
# Clean up legacy keys
del_key('cache_4bit', False)
del_key('cache_8bit', False)
return opts
def add_extension(name, last=False):
if args.extensions is None:
args.extensions = [name]
@ -298,10 +355,14 @@ def load_user_config():
else:
user_config = {}
for model_name in user_config:
user_config[model_name] = transform_legacy_kv_cache_options(user_config[model_name])
return user_config
args.loader = fix_loader_name(args.loader)
args = transform_legacy_kv_cache_options(args)
# Activate the multimodal extension
if args.multimodal_pipeline is not None: