diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index c606912b..05b473b7 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -103,12 +103,6 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         labels = kwargs.get('labels', None)
         past_key_values = kwargs.get('past_key_values', None)
 
-        # Reset the internal sequence state for standalone calls (logit viewer)
-        # or the very first step of a new generation.
-        if past_key_values is None:
-            self.past_seq = None
-            self.past_seq_negative = None
-
         if len(args) > 0:
             if not shared.args.cfg_cache:
                 logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
@@ -125,8 +119,8 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
             ex_cache = self.ex_cache
 
         seq = input_ids[0].tolist()
-        if is_negative and past_key_values is not None and isinstance(past_key_values, list):
-            seq = past_key_values + seq
+        if is_negative and past_key_values is not None:
+            seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
         reset = True
@@ -134,50 +128,97 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         # Maximum number of tokens to process in a single forward pass
         max_chunk_size = 256
 
-        if past_seq is not None:
-            min_length = min(past_seq.shape[0], seq_tensor.shape[0])
-            indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
-            if len(indices) == 0 and seq_tensor.shape[0] > past_seq.shape[0]:
-                reset = False
-
-        # Create a single `params` dictionary that will be used and modified
-        # in-place across all `forward` calls within this function.
-        params = {
-            "attn_mode": "flash_attn",
-            "cache": ex_cache,
-            "batch_shape": (1, self.max_tokens),
-            "reconstruct": False,
-            "past_len": 0
-        }
-
         # Make the forward call
         if labels is None:
-            # If it's an efficient continuation, process only the new tokens
-            if not reset:
-                params["past_len"] = past_seq.shape[0]
-                tokens_to_process = seq_tensor[past_seq.shape[0]:]
-            # Otherwise, process the whole sequence from scratch
-            else:
-                tokens_to_process = seq_tensor
+            if past_seq is not None:
+                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+                if len(indices) > 0:
+                    longest_prefix = indices[0].item()
+                else:
+                    longest_prefix = min_length
 
-            # Process all but the last token of the sequence/sub-sequence
-            if tokens_to_process.shape[0] > 1:
-                prefix_to_process = tokens_to_process[:-1]
+                if longest_prefix > 0:
+                    reset = False
+                    current_len = longest_prefix
+                    remaining_tokens = len(seq_tensor) - longest_prefix - 1
 
-                # Process in chunks if the number of tokens is large
-                for i in range(0, prefix_to_process.shape[0], max_chunk_size):
-                    chunk = prefix_to_process[i:i + max_chunk_size]
-                    self.ex_model.forward(input_ids=chunk.view(1, -1), params=params)
-                    params["past_len"] += chunk.shape[0]
+                    if remaining_tokens > 0:
+                        # Process tokens from longest_prefix to second-to-last token
+                        tokens_to_process = seq_tensor[longest_prefix:-1]
 
-            # Process the last token to get logits
-            last_token = tokens_to_process[-1:].view(1, -1)
-            logits = self.ex_model.forward(input_ids=last_token, params=params).to(input_ids.device).float()
+                        # Process in chunks if the number of tokens is large
+                        for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                            chunk = tokens_to_process[i:i + max_chunk_size]
+                            self.ex_model.forward(
+                                input_ids=chunk.view(1, -1),
+                                params={
+                                    "attn_mode": "flash_attn",
+                                    "cache": ex_cache,
+                                    "past_len": longest_prefix + i,
+                                    "batch_shape": (1, self.max_tokens),
+                                    "reconstruct": False  # Force memory-efficient path
+                                }
+                            )
+
+                        current_len = longest_prefix + remaining_tokens
+
+            if reset:
+                if len(seq_tensor) > 1:
+                    # Process all tokens except the last one
+                    tokens_to_process = seq_tensor[:-1]
+
+                    # Process in chunks if the number of tokens is large
+                    current_len = 0
+                    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                        chunk = tokens_to_process[i:i + max_chunk_size]
+                        self.ex_model.forward(
+                            input_ids=chunk.view(1, -1),
+                            params={
+                                "attn_mode": "flash_attn",
+                                "cache": ex_cache,
+                                "past_len": current_len,
+                                "batch_shape": (1, self.max_tokens),
+                                "reconstruct": False  # Force memory-efficient path
+                            }
+                        )
+                        current_len += chunk.shape[0]
+                else:
+                    current_len = 0
+
+            # Process the last token and get logits
+            logits = self.ex_model.forward(
+                input_ids=seq_tensor[-1:].view(1, -1),
+                params={
+                    "attn_mode": "flash_attn",
+                    "cache": ex_cache,
+                    "past_len": current_len,
+                    "batch_shape": (1, self.max_tokens),
+                    "reconstruct": False  # Force memory-efficient path
+                }
+            ).to(input_ids.device).float()
         else:
             # When processing with labels, handle as a complete sequence
-            params["attn_mode"] = "flash_attn_nc"
-            logits = self.ex_model.forward(input_ids=seq_tensor.view(1,-1), params=params).float()
+            # Process in chunks if the number of tokens is large
+            tokens_to_process = seq_tensor
+            all_logits = None
+            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                chunk = tokens_to_process[i:i + max_chunk_size]
+                chunk_logits = self.ex_model.forward(
+                    input_ids=chunk.view(1, -1),
+                    params={
+                        "attn_mode": "flash_attn_nc",  # No caching for training
+                        "reconstruct": False  # Force memory-efficient path
+                    }
+                ).float()
+
+                if all_logits is None:
+                    all_logits = chunk_logits
+                else:
+                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
+
+            logits = all_logits
 
         if is_negative:
             self.past_seq_negative = seq_tensor