Same as 7f06aec3a1 but for exllamav3_hf

Author: oobabooga
Date:   2025-10-09 12:05:45 -07:00
parent  7f06aec3a1
commit  deb37b821b


@@ -103,6 +103,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         labels = kwargs.get('labels', None)
         past_key_values = kwargs.get('past_key_values', None)
 
+        # Reset the internal sequence state for standalone calls (logit viewer)
+        # or the very first step of a new generation.
+        if past_key_values is None:
+            self.past_seq = None
+            self.past_seq_negative = None
+
         if len(args) > 0:
             if not shared.args.cfg_cache:
                 logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
@@ -119,7 +125,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
             ex_cache = self.ex_cache
 
         seq = input_ids[0].tolist()
-        if is_negative and past_key_values is not None:
+        if is_negative and past_key_values is not None and isinstance(past_key_values, list):
             seq = past_key_values + seq
 
         seq_tensor = torch.tensor(seq)
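
The extra isinstance(past_key_values, list) condition is, as far as I can tell, a type guard: on the CFG negative-prompt path the module passes its own bookkeeping around as a plain list of token ids, while other callers may hand in a different cache object, and only a list can be safely concatenated onto seq. A small hedged illustration of that reading (plain Python, not repository code):

# Hedged illustration of the guard, not code from the repository.
seq = [5, 6, 7]

past_key_values = [1, 2, 3, 4]   # the module's own bookkeeping: a list of token ids
if past_key_values is not None and isinstance(past_key_values, list):
    seq = past_key_values + seq  # safe: list + list
print(seq)                       # [1, 2, 3, 4, 5, 6, 7]

past_key_values = object()       # stand-in for a non-list cache object
if past_key_values is not None and isinstance(past_key_values, list):
    seq = past_key_values + seq  # skipped, so no TypeError can be raised
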
@@ -128,97 +134,50 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         # Maximum number of tokens to process in a single forward pass
         max_chunk_size = 256
 
-        # Make the forward call
-        if labels is None:
-            if past_seq is not None:
-                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
-                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
-                if len(indices) > 0:
-                    longest_prefix = indices[0].item()
-                else:
-                    longest_prefix = min_length
-
-                if longest_prefix > 0:
-                    reset = False
-                    current_len = longest_prefix
-                    remaining_tokens = len(seq_tensor) - longest_prefix - 1
-
-                    if remaining_tokens > 0:
-                        # Process tokens from longest_prefix to second-to-last token
-                        tokens_to_process = seq_tensor[longest_prefix:-1]
-
-                        # Process in chunks if the number of tokens is large
-                        for i in range(0, tokens_to_process.shape[0], max_chunk_size):
-                            chunk = tokens_to_process[i:i + max_chunk_size]
-                            self.ex_model.forward(
-                                input_ids=chunk.view(1, -1),
-                                params={
-                                    "attn_mode": "flash_attn",
-                                    "cache": ex_cache,
-                                    "past_len": longest_prefix + i,
-                                    "batch_shape": (1, self.max_tokens),
-                                    "reconstruct": False  # Force memory-efficient path
-                                }
-                            )
-
-                        current_len = longest_prefix + remaining_tokens
-
-            if reset:
-                if len(seq_tensor) > 1:
-                    # Process all tokens except the last one
-                    tokens_to_process = seq_tensor[:-1]
-
-                    # Process in chunks if the number of tokens is large
-                    current_len = 0
-                    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
-                        chunk = tokens_to_process[i:i + max_chunk_size]
-                        self.ex_model.forward(
-                            input_ids=chunk.view(1, -1),
-                            params={
-                                "attn_mode": "flash_attn",
-                                "cache": ex_cache,
-                                "past_len": current_len,
-                                "batch_shape": (1, self.max_tokens),
-                                "reconstruct": False  # Force memory-efficient path
-                            }
-                        )
-                        current_len += chunk.shape[0]
-                else:
-                    current_len = 0
-
-            # Process the last token and get logits
-            logits = self.ex_model.forward(
-                input_ids=seq_tensor[-1:].view(1, -1),
-                params={
-                    "attn_mode": "flash_attn",
-                    "cache": ex_cache,
-                    "past_len": current_len,
-                    "batch_shape": (1, self.max_tokens),
-                    "reconstruct": False  # Force memory-efficient path
-                }
-            ).to(input_ids.device).float()
-        else:
-            # When processing with labels, handle as a complete sequence
-            # Process in chunks if the number of tokens is large
-            tokens_to_process = seq_tensor
-            all_logits = None
-
-            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
-                chunk = tokens_to_process[i:i + max_chunk_size]
-                chunk_logits = self.ex_model.forward(
-                    input_ids=chunk.view(1, -1),
-                    params={
-                        "attn_mode": "flash_attn_nc",  # No caching for training
-                        "reconstruct": False  # Force memory-efficient path
-                    }
-                ).float()
-
-                if all_logits is None:
-                    all_logits = chunk_logits
-                else:
-                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
-
-            logits = all_logits
+        if past_seq is not None:
+            min_length = min(past_seq.shape[0], seq_tensor.shape[0])
+            indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
+            if len(indices) == 0 and seq_tensor.shape[0] > past_seq.shape[0]:
+                reset = False
+
+        # Create a single `params` dictionary that will be used and modified
+        # in-place across all `forward` calls within this function.
+        params = {
+            "attn_mode": "flash_attn",
+            "cache": ex_cache,
+            "batch_shape": (1, self.max_tokens),
+            "reconstruct": False,
+            "past_len": 0
+        }
+
+        # Make the forward call
+        if labels is None:
+            # If it's an efficient continuation, process only the new tokens
+            if not reset:
+                params["past_len"] = past_seq.shape[0]
+                tokens_to_process = seq_tensor[past_seq.shape[0]:]
+            # Otherwise, process the whole sequence from scratch
+            else:
+                tokens_to_process = seq_tensor
+
+            # Process all but the last token of the sequence/sub-sequence
+            if tokens_to_process.shape[0] > 1:
+                prefix_to_process = tokens_to_process[:-1]
+
+                # Process in chunks if the number of tokens is large
+                for i in range(0, prefix_to_process.shape[0], max_chunk_size):
+                    chunk = prefix_to_process[i:i + max_chunk_size]
+                    self.ex_model.forward(input_ids=chunk.view(1, -1), params=params)
+                    params["past_len"] += chunk.shape[0]
+
+            # Process the last token to get logits
+            last_token = tokens_to_process[-1:].view(1, -1)
+            logits = self.ex_model.forward(input_ids=last_token, params=params).to(input_ids.device).float()
+        else:
+            # When processing with labels, handle as a complete sequence
+            params["attn_mode"] = "flash_attn_nc"
+            logits = self.ex_model.forward(input_ids=seq_tensor.view(1,-1), params=params).float()
 
         if is_negative:
             self.past_seq_negative = seq_tensor
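
Taken together, the new forward path is: decide whether the incoming sequence extends the cached one, build one params dict, prefill every token except the last in chunks of max_chunk_size while advancing past_len in place, and keep logits only from the final token's forward call. Below is a self-contained sketch of that pattern; FakeModel and chunked_prefill are hypothetical stand-ins for exllamav3's model.forward() and the method above, kept only to show the chunking and past_len bookkeeping.

import torch


class FakeModel:
    # Stand-in backend: pretends to attend over `past_len` cached tokens plus
    # the new chunk, and returns one logit row per new token.
    def forward(self, input_ids, params):
        vocab_size = 32
        print(f"forward: past_len={params['past_len']}, new_tokens={input_ids.shape[1]}")
        return torch.zeros(1, input_ids.shape[1], vocab_size)


def chunked_prefill(model, seq_tensor, cached_len=0, max_chunk_size=256):
    # Shared params dict, mutated in place across calls (mirrors the diff).
    params = {"attn_mode": "flash_attn", "past_len": cached_len}

    # Only the tokens that are not already in the cache need processing.
    tokens_to_process = seq_tensor[cached_len:]

    # Prefill everything except the last token, in bounded chunks.
    if tokens_to_process.shape[0] > 1:
        prefix_to_process = tokens_to_process[:-1]
        for i in range(0, prefix_to_process.shape[0], max_chunk_size):
            chunk = prefix_to_process[i:i + max_chunk_size]
            model.forward(input_ids=chunk.view(1, -1), params=params)
            params["past_len"] += chunk.shape[0]

    # Only the final token's forward pass needs to return logits.
    last_token = tokens_to_process[-1:].view(1, -1)
    return model.forward(input_ids=last_token, params=params)


logits = chunked_prefill(FakeModel(), torch.arange(600), cached_len=100, max_chunk_size=256)
print(logits.shape)  # torch.Size([1, 1, 32])
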