ExLlamav3_HF: Optimize prefill and fix CFG cache initialization

This commit is contained in:
oobabooga 2026-03-04 11:09:58 -08:00
parent 9b916f02cd
commit d8af0505a8

View file

@@ -84,6 +84,12 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
self.ex_model.load(**load_params)
self.past_seq = None
self.max_tokens = max_tokens
self.layer_type = layer_type
self.cache_kwargs = cache_kwargs
if shared.args.cfg_cache:
self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
self.past_seq_negative = None
def _validate_model_class(self):
pass
@@ -126,7 +132,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
reset = True
# Maximum number of tokens to process in a single forward pass
max_chunk_size = 256
max_chunk_size = 2048
# Make the forward call
if labels is None:
@@ -147,17 +153,16 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
# Process tokens from longest_prefix to second-to-last token
tokens_to_process = seq_tensor[longest_prefix:-1]
# Process in chunks if the number of tokens is large
# Use prefill() to fill the cache without computing logits
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size]
self.ex_model.forward(
self.ex_model.prefill(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn",
"cache": ex_cache,
"past_len": longest_prefix + i,
"batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
}
)
@@ -168,18 +173,17 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
# Process all tokens except the last one
tokens_to_process = seq_tensor[:-1]
# Process in chunks if the number of tokens is large
# Use prefill() to fill the cache without computing logits
current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size]
self.ex_model.forward(
self.ex_model.prefill(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn",
"cache": ex_cache,
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
}
)
current_len += chunk.shape[0]
@@ -194,7 +198,6 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
"cache": ex_cache,
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
"reconstruct": False # Force memory-efficient path
}
).to(input_ids.device).float()
else:
@@ -208,8 +211,7 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
chunk_logits = self.ex_model.forward(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn_nc", # No caching for training
"reconstruct": False # Force memory-efficient path
"attn_mode": "flash_attn_nc",
}
).float()