From deb37b821bb13f241dcc3e6aff7cbf41bbae143f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 9 Oct 2025 12:05:45 -0700 Subject: [PATCH] Same as 7f06aec3a1fc2e6924d87035483ce10ce65af058 but for exllamav3_hf --- modules/exllamav3_hf.py | 129 ++++++++++++++-------------------------- 1 file changed, 44 insertions(+), 85 deletions(-) diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index 05b473b7..c606912b 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -103,6 +103,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): labels = kwargs.get('labels', None) past_key_values = kwargs.get('past_key_values', None) + # Reset the internal sequence state for standalone calls (logit viewer) + # or the very first step of a new generation. + if past_key_values is None: + self.past_seq = None + self.past_seq_negative = None + if len(args) > 0: if not shared.args.cfg_cache: logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.") @@ -119,8 +125,8 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): ex_cache = self.ex_cache seq = input_ids[0].tolist() - if is_negative and past_key_values is not None: - seq = past_key_values + seq + if is_negative and past_key_values is not None and isinstance(past_key_values, list): + seq = past_key_values + seq seq_tensor = torch.tensor(seq) reset = True @@ -128,97 +134,50 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin): # Maximum number of tokens to process in a single forward pass max_chunk_size = 256 + if past_seq is not None: + min_length = min(past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) + if len(indices) == 0 and seq_tensor.shape[0] > past_seq.shape[0]: + reset = False + + # Create a single `params` dictionary that will be used and modified + # in-place across all `forward` calls within this function. 
+ params = { + "attn_mode": "flash_attn", + "cache": ex_cache, + "batch_shape": (1, self.max_tokens), + "reconstruct": False, + "past_len": 0 + } + # Make the forward call if labels is None: - if past_seq is not None: - min_length = min(past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) - if len(indices) > 0: - longest_prefix = indices[0].item() - else: - longest_prefix = min_length + # If it's an efficient continuation, process only the new tokens + if not reset: + params["past_len"] = past_seq.shape[0] + tokens_to_process = seq_tensor[past_seq.shape[0]:] + # Otherwise, process the whole sequence from scratch + else: + tokens_to_process = seq_tensor - if longest_prefix > 0: - reset = False - current_len = longest_prefix - remaining_tokens = len(seq_tensor) - longest_prefix - 1 + # Process all but the last token of the sequence/sub-sequence + if tokens_to_process.shape[0] > 1: + prefix_to_process = tokens_to_process[:-1] - if remaining_tokens > 0: - # Process tokens from longest_prefix to second-to-last token - tokens_to_process = seq_tensor[longest_prefix:-1] + # Process in chunks if the number of tokens is large + for i in range(0, prefix_to_process.shape[0], max_chunk_size): + chunk = prefix_to_process[i:i + max_chunk_size] + self.ex_model.forward(input_ids=chunk.view(1, -1), params=params) + params["past_len"] += chunk.shape[0] - # Process in chunks if the number of tokens is large - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": longest_prefix + i, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ) - - current_len = longest_prefix + remaining_tokens - - if reset: - if len(seq_tensor) > 1: - # Process all tokens except the last one - tokens_to_process = seq_tensor[:-1] - - # Process in chunks if the number of tokens is large - current_len = 0 - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": current_len, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ) - current_len += chunk.shape[0] - else: - current_len = 0 - - # Process the last token and get logits - logits = self.ex_model.forward( - input_ids=seq_tensor[-1:].view(1, -1), - params={ - "attn_mode": "flash_attn", - "cache": ex_cache, - "past_len": current_len, - "batch_shape": (1, self.max_tokens), - "reconstruct": False # Force memory-efficient path - } - ).to(input_ids.device).float() + # Process the last token to get logits + last_token = tokens_to_process[-1:].view(1, -1) + logits = self.ex_model.forward(input_ids=last_token, params=params).to(input_ids.device).float() else: # When processing with labels, handle as a complete sequence - # Process in chunks if the number of tokens is large - tokens_to_process = seq_tensor - all_logits = None + params["attn_mode"] = "flash_attn_nc" + logits = self.ex_model.forward(input_ids=seq_tensor.view(1,-1), params=params).float() - for i in range(0, tokens_to_process.shape[0], max_chunk_size): - chunk = tokens_to_process[i:i + max_chunk_size] - chunk_logits = self.ex_model.forward( - input_ids=chunk.view(1, -1), - params={ - 
"attn_mode": "flash_attn_nc", # No caching for training - "reconstruct": False # Force memory-efficient path - } - ).float() - - if all_logits is None: - all_logits = chunk_logits - else: - all_logits = torch.cat([all_logits, chunk_logits], dim=1) - - logits = all_logits if is_negative: self.past_seq_negative = seq_tensor