From 55283bb8f1a1f7fcef65b8f35ee81658e35a46af Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 30 Apr 2025 18:43:45 -0700
Subject: [PATCH] Fix CFG with ExLlamaV2_HF (closes #6937)

---
 modules/exllamav2_hf.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index eb801940..4aa46375 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -65,7 +65,7 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
         elif kv_cache_type == 'q4':
             cache_type = ExLlamaV2Cache_Q4
         else:
-            raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
+            raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
 
         # Use TP if specified
         if shared.args.enable_tp:
@@ -78,12 +78,10 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
         self.past_seq = None
 
         if shared.args.cfg_cache:
-            if shared.args.cache_8bit:
-                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
-            elif shared.args.cache_4bit:
-                self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)
+            if shared.args.enable_tp:
+                self.ex_cache_negative = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
             else:
-                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
+                self.ex_cache_negative = cache_type(self.ex_model, lazy=shared.args.autosplit)
 
             self.past_seq_negative = None
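
Note (annotation, not part of the commit): the negative cache used for CFG
previously keyed off the legacy cache_8bit/cache_4bit flags, so with the
newer kv_cache_type options (fp8, q8, q6) it could be built with a different
cache class than the primary cache, and it ignored tensor-parallel and
autosplit loading entirely. The patch makes the negative cache follow the
same construction path as the primary one. A minimal sketch of that shared
pattern, assuming cache_type has already been resolved from kv_cache_type as
in the first hunk (build_cache is a hypothetical helper, not part of the
codebase):

    from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_TP

    def build_cache(ex_model, cache_type, enable_tp, autosplit):
        # Tensor-parallel models wrap the chosen cache class via base=;
        # otherwise instantiate the class directly, deferring allocation
        # (lazy=True) when the model weights are autosplit across GPUs.
        if enable_tp:
            return ExLlamaV2Cache_TP(ex_model, base=cache_type)
        return cache_type(ex_model, lazy=autosplit)

With shared.args.cfg_cache enabled, self.ex_cache and self.ex_cache_negative
are now constructed identically, so the CFG negative-prompt pass sees the
same cache type, quantization, and placement as the positive pass.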