From d8af0505a805dd5ebadeb9e6aef69f8da54599db Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 4 Mar 2026 11:09:58 -0800 Subject: [PATCH] ExLlamav3_HF: Optimize prefill and fix CFG cache initialization --- modules/exllamav3_hf.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index e05f8d7d..b4b6ad20 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -84,6 +84,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): self.ex_model.load(**load_params) self.past_seq = None self.max_tokens = max_tokens + self.layer_type = layer_type + self.cache_kwargs = cache_kwargs + + if shared.args.cfg_cache: + self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) + self.past_seq_negative = None def _validate_model_class(self): pass @@ -126,7 +132,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): reset = True # Maximum number of tokens to process in a single forward pass - max_chunk_size = 256 + max_chunk_size = 2048 # Make the forward call if labels is None: @@ -147,17 +153,16 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): # Process tokens from longest_prefix to second-to-last token tokens_to_process = seq_tensor[longest_prefix:-1] - # Process in chunks if the number of tokens is large + # Use prefill() to fill the cache without computing logits for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( + self.ex_model.prefill( input_ids=chunk.view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": longest_prefix + i, "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path } ) @@ -168,18 +173,17 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): # Process all tokens except the last one tokens_to_process = seq_tensor[:-1] - # Process in chunks if the number of tokens is large + # Use prefill() to fill the cache without computing logits current_len = 0 for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( + self.ex_model.prefill( input_ids=chunk.view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": current_len, "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path } ) current_len += chunk.shape[0] @@ -194,7 +198,6 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): "cache": ex_cache, "past_len": current_len, "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path } ).to(input_ids.device).float() else: @@ -208,8 +211,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): chunk_logits = self.ex_model.forward( input_ids=chunk.view(1, -1), params={ - "attn_mode": "flash_attn_nc", # No caching for training - "reconstruct": False # Force memory-efficient path + "attn_mode": "flash_attn_nc", } ).float()