diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index eb801940..4aa46375 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -65,7 +65,7 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
         elif kv_cache_type == 'q4':
             cache_type = ExLlamaV2Cache_Q4
         else:
-            raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
+            raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
 
         # Use TP if specified
         if shared.args.enable_tp:
@@ -78,12 +78,10 @@ class Exllamav2HF(PreTrainedModel, GenerationMixin):
         self.past_seq = None
 
         if shared.args.cfg_cache:
-            if shared.args.cache_8bit:
-                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
-            elif shared.args.cache_4bit:
-                self.ex_cache_negative = ExLlamaV2Cache_Q4(self.ex_model)
+            if shared.args.enable_tp:
+                self.ex_cache_negative = ExLlamaV2Cache_TP(self.ex_model, base=cache_type)
             else:
-                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
+                self.ex_cache_negative = cache_type(self.ex_model, lazy=shared.args.autosplit)
             self.past_seq_negative = None
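
For context, a minimal sketch of the cache-selection flow this diff converges on. The CACHE_TYPES mapping and make_cache helper are hypothetical illustrations, not code from the repository; the ExLlamaV2Cache* classes and the base=/lazy= keyword arguments are taken from the diff itself, and the q8/q6 entries are inferred from the options listed in the error message.

# Hypothetical sketch; make_cache and CACHE_TYPES do not exist in the
# repository. The cache classes and keyword arguments come from the
# diff above.
from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
    ExLlamaV2Cache_TP,
)

CACHE_TYPES = {
    'fp16': ExLlamaV2Cache,
    'fp8': ExLlamaV2Cache_8bit,
    'q8': ExLlamaV2Cache_Q8,
    'q6': ExLlamaV2Cache_Q6,
    'q4': ExLlamaV2Cache_Q4,
}

def make_cache(model, kv_cache_type, enable_tp=False, autosplit=False):
    if kv_cache_type not in CACHE_TYPES:
        # The first hunk's fix: report the user-supplied string
        # (kv_cache_type), not the cache_type variable, which is
        # still unassigned on this error path.
        raise ValueError(f"Invalid cache type for ExLlamaV2: {kv_cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")

    cache_type = CACHE_TYPES[kv_cache_type]
    # The second hunk's change: the negative (CFG) cache follows the
    # same path as the main cache instead of keeping its own
    # 8-bit/4-bit branches, so it honors tensor parallelism and
    # lazy (autosplit) allocation too.
    if enable_tp:
        return ExLlamaV2Cache_TP(model, base=cache_type)
    return cache_type(model, lazy=autosplit)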