Fix ExLlamav3 OOM on prompt logprobs and qwen3_5_moe HF compat

This commit is contained in:
oobabooga 2026-04-01 19:08:37 -07:00
parent 328534b762
commit 4073164be0
3 changed files with 17 additions and 61 deletions

View file

@@ -91,17 +91,14 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
import torch
if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
# Native ExLlamav3: call the underlying Model.forward() directly
# Native ExLlamav3: call the underlying Model.forward() in chunks
# to avoid OOM from giant logits tensors (seq_len * vocab_size * 4 bytes)
input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
input_ids_tensor = input_ids_tensor.view(-1).cpu()
with torch.no_grad():
logits = model.model.forward(
input_ids=input_ids_tensor,
params={
"attn_mode": "flash_attn",
"cache": model.cache,
"past_len": 0,
"batch_shape": (1, model.max_tokens),
}
input_ids=input_ids_tensor.view(1, -1),
params={"attn_mode": "flash_attn_nc"}
).float().cpu()
elif hasattr(model, 'forward'):

View file

@@ -530,39 +530,14 @@ class Exllamav3Model:
def get_logits(self, token_ids, **kwargs):
"""
Process a batch of token_ids and return the logits for the last token.
This will reset and overwrite the model's cache.
Uses flash_attn_nc (no cache) for correct results with recurrent models.
"""
# Initialize a single params dictionary that will be updated in-place
params = {
"cache": self.cache,
"reconstruct": False,
"attn_mode": "flash_attn",
"batch_shape": (1, self.max_tokens),
"past_len": 0
}
params.update(kwargs)
# Process prefix tokens to fill the cache and generate recurrent state
if token_ids.shape[-1] > 1:
prefix_ids = token_ids[:, :-1]
# This forward call updates the 'params' dict with the recurrent state
self.model.forward(
input_ids=prefix_ids,
params=params
)
# Update past_len for the next call
params["past_len"] = prefix_ids.shape[-1]
# Process the last token, now using the state-filled 'params' dict
last_token_ids = token_ids[:, -1:]
logits = self.model.forward(
input_ids=last_token_ids,
params=params
input_ids=token_ids,
params={"attn_mode": "flash_attn_nc"}
)
return logits.float().cpu()
return logits[:, -1:, :].float().cpu()
def encode(self, string, **kwargs):
add_bos = kwargs.pop('add_bos', True)

View file

@ -26,6 +26,9 @@ except Exception:
class Exllamav3HF(PreTrainedModel, GenerationMixin):
def __init__(self, model_dir):
hf_config = PretrainedConfig.from_pretrained(model_dir)
# Ensure text_config is a proper object, not a dict (fixes qwen3_5_moe + transformers compat)
if isinstance(getattr(hf_config, 'text_config', None), dict):
hf_config.text_config = PretrainedConfig(**hf_config.text_config)
super().__init__(hf_config)
exl3_config = Config.from_directory(model_dir)
@@ -199,30 +202,11 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
}
).to(input_ids.device).float()
else:
# Labels path: use cache for cross-chunk attention.
tokens_to_process = seq_tensor
all_logits = None
current_len = 0
for i in range(0, tokens_to_process.shape[0], max_chunk_size):
chunk = tokens_to_process[i:i + max_chunk_size]
chunk_logits = self.ex_model.forward(
input_ids=chunk.view(1, -1),
params={
"attn_mode": "flash_attn",
"cache": ex_cache,
"past_len": current_len,
"batch_shape": (1, self.max_tokens),
}
).float()
current_len += chunk.shape[0]
if all_logits is None:
all_logits = chunk_logits
else:
all_logits = torch.cat([all_logits, chunk_logits], dim=1)
logits = all_logits
# Labels path: single pass without cache for correct logits
logits = self.ex_model.forward(
input_ids=seq_tensor.view(1, -1),
params={"attn_mode": "flash_attn_nc"}
).float().cpu()
if is_negative:
self.past_seq_negative = seq_tensor