Fix ExLlamaV3_HF leaking memory (attempt)

oobabooga 2025-04-27 21:04:02 -07:00
parent 965ca7948f
commit ee0592473c

@@ -118,6 +118,9 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
         seq_tensor = torch.tensor(seq)
         reset = True
 
+        # Maximum number of tokens to process in a single forward pass
+        max_chunk_size = 2048
+
         # Make the forward call
         if labels is None:
             if past_seq is not None:
@@ -131,54 +134,84 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
                 if longest_prefix > 0:
                     reset = False
                     current_len = longest_prefix
 
-                    if len(seq_tensor) - longest_prefix > 1:
-                        self.ex_model.forward(
-                            input_ids=seq_tensor[longest_prefix:-1].view(1, -1),
-                            params={
-                                "attn_mode": "flash_attn",
-                                "cache": ex_cache,
-                                "past_len": longest_prefix,
-                                "batch_shape": (1, self.max_tokens)
-                            }
-                        )
-                        current_len = longest_prefix + len(seq_tensor) - longest_prefix - 1
+                    remaining_tokens = len(seq_tensor) - longest_prefix - 1
+                    if remaining_tokens > 0:
+                        # Process tokens from longest_prefix to second-to-last token
+                        tokens_to_process = seq_tensor[longest_prefix:-1]
+
+                        # Process in chunks if the number of tokens is large
+                        for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                            chunk = tokens_to_process[i:i + max_chunk_size]
+                            self.ex_model.forward(
+                                input_ids=chunk.view(1, -1),
+                                params={
+                                    "attn_mode": "flash_attn",
+                                    "cache": ex_cache,
+                                    "past_len": longest_prefix + i,
+                                    "batch_shape": (1, self.max_tokens),
+                                    "reconstruct": False  # Force memory-efficient path
+                                }
+                            )
+                        current_len = longest_prefix + remaining_tokens
 
             if reset:
                 if len(seq_tensor) > 1:
-                    self.ex_model.forward(
-                        input_ids=seq_tensor[:-1].view(1, -1),
-                        params={
-                            "attn_mode": "flash_attn",
-                            "cache": ex_cache,
-                            "past_len": 0,
-                            "batch_shape": (1, self.max_tokens)
-                        }
-                    )
-                    current_len = len(seq_tensor) - 1
+                    # Process all tokens except the last one
+                    tokens_to_process = seq_tensor[:-1]
+
+                    # Process in chunks if the number of tokens is large
+                    current_len = 0
+                    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                        chunk = tokens_to_process[i:i + max_chunk_size]
+                        self.ex_model.forward(
+                            input_ids=chunk.view(1, -1),
+                            params={
+                                "attn_mode": "flash_attn",
+                                "cache": ex_cache,
+                                "past_len": current_len,
+                                "batch_shape": (1, self.max_tokens),
+                                "reconstruct": False  # Force memory-efficient path
+                            }
+                        )
+                        current_len += chunk.shape[0]
                 else:
                     current_len = 0
 
             # Process the last token and get logits
             logits = self.ex_model.forward(
                 input_ids=seq_tensor[-1:].view(1, -1),
                 params={
                     "attn_mode": "flash_attn",
                     "cache": ex_cache,
                     "past_len": current_len,
-                    "batch_shape": (1, self.max_tokens)
+                    "batch_shape": (1, self.max_tokens),
+                    "reconstruct": False  # Force memory-efficient path
                 }
             ).to(input_ids.device).float()
         else:
-            logits = self.ex_model.forward(
-                input_ids=seq_tensor.view(1, -1),
-                params={
-                    "attn_mode": "flash_attn",
-                    "cache": ex_cache,
-                    "past_len": 0,
-                    "batch_shape": (1, self.max_tokens)
-                }
-            ).float()
+            # When processing with labels, handle as a complete sequence
+            # Process in chunks if the number of tokens is large
+            tokens_to_process = seq_tensor
+            all_logits = None
+
+            for i in range(0, tokens_to_process.shape[0], max_chunk_size):
+                chunk = tokens_to_process[i:i + max_chunk_size]
+                chunk_logits = self.ex_model.forward(
+                    input_ids=chunk.view(1, -1),
+                    params={
+                        "attn_mode": "flash_attn_nc",  # No caching for training
+                        "reconstruct": False  # Force memory-efficient path
+                    }
+                ).float()
+
+                if all_logits is None:
+                    all_logits = chunk_logits
+                else:
+                    all_logits = torch.cat([all_logits, chunk_logits], dim=1)
+
+            logits = all_logits
 
         if is_negative:
            self.past_seq_negative = seq_tensor
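
For readers skimming the diff, the core pattern is the same in all three branches: split the prompt into fixed-size chunks, advance past_len by the chunk length after each forward call, and drop the intermediate logits. The sketch below is a minimal, self-contained illustration of that loop; fake_forward, the params keys, and the name chunked_prefill are stand-ins mirroring the diff's self.ex_model.forward call, not ExLlamaV3's actual interface.

import torch


# Hypothetical stand-in for self.ex_model.forward() from the diff above; any
# callable that takes (1, n) token ids plus a params dict and returns
# (1, n, vocab) logits fits the pattern. This is NOT ExLlamaV3's real API.
def fake_forward(input_ids, params):
    vocab_size = 32
    return torch.randn(1, input_ids.shape[1], vocab_size)


def chunked_prefill(seq_tensor, past_len=0, max_chunk_size=2048):
    """Feed every token except the last into the cache in fixed-size chunks,
    so peak memory scales with max_chunk_size instead of the prompt length."""
    tokens_to_process = seq_tensor[:-1]
    current_len = past_len
    for i in range(0, tokens_to_process.shape[0], max_chunk_size):
        chunk = tokens_to_process[i:i + max_chunk_size]
        # Intermediate logits are discarded; only the cache state matters,
        # which is why the diff also passes "reconstruct": False here.
        fake_forward(
            input_ids=chunk.view(1, -1),
            params={"past_len": current_len, "reconstruct": False},
        )
        current_len += chunk.shape[0]  # advance the cache position
    return current_len


seq = torch.arange(5000)
print(chunked_prefill(seq))  # 4999: cache covers everything but the last token

Because current_len advances after every chunk, each forward call still attends to the full cached prefix, so the chunked pass should leave the cache in the same state as one monolithic pass over the whole prompt.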