From ee0592473c16db1917c98a24ec24dce36fd00bed Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 27 Apr 2025 21:04:02 -0700
Subject: [PATCH] Fix ExLlamaV3_HF leaking memory (attempt)

---
 modules/exllamav3_hf.py | 95 +++++++++++++++++++++++++++--------------
 1 file changed, 64 insertions(+), 31 deletions(-)

diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index f15fc0b2..12b22f64 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -118,6 +118,9 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         seq_tensor = torch.tensor(seq)
         reset = True
 
+        # Maximum number of tokens to process in a single forward pass
+        max_chunk_size = 2048
+
         # Make the forward call
         if labels is None:
             if past_seq is not None:
@@ -131,54 +134,84 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
                 if longest_prefix > 0:
                     reset = False
                     current_len = longest_prefix
-                    if len(seq_tensor) - longest_prefix > 1:
-                        self.ex_model.forward(
-                            input_ids=seq_tensor[longest_prefix:-1].view(1, -1),
-                            params={
-                                "attn_mode": "flash_attn",
-                                "cache": ex_cache,
-                                "past_len": longest_prefix,
-                                "batch_shape": (1, self.max_tokens)
-                            }
-                        )
+                    remaining_tokens = len(seq_tensor) - longest_prefix - 1
 
-                        current_len = longest_prefix + len(seq_tensor) - longest_prefix - 1
+                    if remaining_tokens > 0:
+                        # Process tokens from longest_prefix to second-to-last token
+                        tokens_to_process = seq_tensor[longest_prefix:-1]
+
+                        # Process in chunks if the number of tokens is large
+                        for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                            chunk = tokens_to_process[i:i + max_chunk_size]
+                            self.ex_model.forward(
+                                input_ids=chunk.view(1, -1),
+                                params={
+                                    "attn_mode": "flash_attn",
+                                    "cache": ex_cache,
+                                    "past_len": longest_prefix + i,
+                                    "batch_shape": (1, self.max_tokens),
+                                    "reconstruct": False  # Force memory-efficient path
+                                }
+                            )
+
+                        current_len = longest_prefix + remaining_tokens
 
             if reset:
                 if len(seq_tensor) > 1:
-                    self.ex_model.forward(
-                        input_ids=seq_tensor[:-1].view(1, -1),
-                        params={
-                            "attn_mode": "flash_attn",
-                            "cache": ex_cache,
-                            "past_len": 0,
-                            "batch_shape": (1, self.max_tokens)
-                        }
-                    )
+                    # Process all tokens except the last one
+                    tokens_to_process = seq_tensor[:-1]
 
-                    current_len = len(seq_tensor) - 1
+                    # Process in chunks if the number of tokens is large
+                    current_len = 0
+                    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                        chunk = tokens_to_process[i:i + max_chunk_size]
+                        self.ex_model.forward(
+                            input_ids=chunk.view(1, -1),
+                            params={
+                                "attn_mode": "flash_attn",
+                                "cache": ex_cache,
+                                "past_len": current_len,
+                                "batch_shape": (1, self.max_tokens),
+                                "reconstruct": False  # Force memory-efficient path
+                            }
+                        )
+                        current_len += chunk.shape[0]
                 else:
                     current_len = 0
 
+            # Process the last token and get logits
             logits = self.ex_model.forward(
                 input_ids=seq_tensor[-1:].view(1, -1),
                 params={
                     "attn_mode": "flash_attn",
                     "cache": ex_cache,
                     "past_len": current_len,
-                    "batch_shape": (1, self.max_tokens)
+                    "batch_shape": (1, self.max_tokens),
+                    "reconstruct": False  # Force memory-efficient path
                 }
             ).to(input_ids.device).float()
         else:
-            logits = self.ex_model.forward(
-                input_ids=seq_tensor.view(1, -1),
-                params={
-                    "attn_mode": "flash_attn",
-                    "cache": ex_cache,
-                    "past_len": 0,
-                    "batch_shape": (1, self.max_tokens)
-                }
-            ).float()
+            # When processing with labels, handle as a complete sequence
+            # Process in chunks if the number of tokens is large
+            tokens_to_process = seq_tensor
+            all_logits = None
+
+            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                chunk = tokens_to_process[i:i + max_chunk_size]
+                chunk_logits = self.ex_model.forward(
+                    input_ids=chunk.view(1, -1),
+                    params={
+                        "attn_mode": "flash_attn_nc",  # No caching for training
+                        "reconstruct": False  # Force memory-efficient path
+                    }
+                ).float()
+
+                if all_logits is None:
+                    all_logits = chunk_logits
+                else:
+                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
+
+            logits = all_logits
 
         if is_negative:
             self.past_seq_negative = seq_tensor
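
The core idea of the patch, independent of ExLlamaV3 specifics, is to prefill a long prompt in fixed-size chunks rather than one forward pass, advancing the cache offset by the chunk length each time so peak activation memory stays bounded. Below is a minimal sketch of that pattern only; the `forward(ids, cache, past_len)` callable is a hypothetical stand-in for the real `ex_model.forward()` call, not the library API.

    # Minimal sketch of chunked prefill, assuming a hypothetical
    # forward(ids, cache, past_len) stand-in for ex_model.forward().
    import torch

    def chunked_prefill(forward, cache, tokens: torch.Tensor, max_chunk_size: int = 2048) -> int:
        """Feed `tokens` through the model in chunks of at most `max_chunk_size`
        and return the number of positions now stored in the cache."""
        past_len = 0
        for i in range(0, tokens.shape[0], max_chunk_size):
            chunk = tokens[i:i + max_chunk_size]
            # Each call materializes activations for one chunk only, so peak
            # memory no longer grows with the full prompt length.
            forward(chunk.view(1, -1), cache, past_len)
            past_len += chunk.shape[0]
        return past_len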