API: Optimize prompt logprobs and refactor ExLlamav3 forward pass

This commit is contained in:
oobabooga 2026-04-02 14:30:59 -03:00
parent c10c6e87ae
commit ea1f8c71f2
2 changed files with 58 additions and 25 deletions

View file

@@ -90,16 +90,8 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
import torch
if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
# Native ExLlamav3: call the underlying Model.forward() in chunks
# to avoid OOM from giant logits tensors (seq_len * vocab_size * 4 bytes)
input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
input_ids_tensor = input_ids_tensor.view(-1).cpu()
with torch.no_grad():
logits = model.model.forward(
input_ids=input_ids_tensor.view(1, -1),
params={"attn_mode": "flash_attn_nc"}
).float().cpu()
if hasattr(model, 'get_prompt_logits'):
logits = model.get_prompt_logits(input_ids)
elif hasattr(model, 'forward'):
# HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
@@ -111,26 +103,54 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
# not just the last token (some HF wrappers like ExLlamav3_HF
# only compute the last-token logits when labels are absent).
outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor)
logits = outputs.logits.float().cpu()
logits = outputs.logits # keep on GPU, (1, seq_len, vocab) in model dtype
del outputs
else:
return []
entries = [{"token": first_token_str, "null_logprob": True}]
# Batch logsumexp and topk as single operations across all positions
# to avoid per-position kernel launch overhead.
prompt_logits = logits[0, :n_tokens - 1] # positions 0..n-2 predict tokens 1..n-1
k = min(logprobs_count, prompt_logits.shape[-1])
all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1)
all_lse = torch.logsumexp(prompt_logits, dim=-1)
all_top_log_probs = all_top_values - all_lse.unsqueeze(-1)
# Batch-decode all unique token IDs to avoid O(N*k) individual decode calls
logprobs_count = max(logprobs_count, 1)
k = min(logprobs_count, logits.shape[-1])
chunk_size = 2048
unique_ids = set(int(tid) for tid in token_ids[1:])
unique_ids.update(int(tid) for tid in all_top_indices.flatten().tolist())
decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids}
# Process logits in chunks on GPU, only move top-K results to CPU
all_top_log_probs_list = []
all_top_indices_list = []
all_actual_lps = []
for start in range(0, n_tokens - 1, chunk_size):
end = min(start + chunk_size, n_tokens - 1)
chunk_logits = logits[0, start:end].float() # (chunk, vocab) on GPU
chunk_lse = torch.logsumexp(chunk_logits, dim=-1)
chunk_top_values, chunk_top_indices = torch.topk(chunk_logits, k=k, dim=-1)
chunk_top_log_probs = chunk_top_values - chunk_lse.unsqueeze(-1)
# Compute logprob for actual next tokens in this chunk
chunk_top_sets = [set(chunk_top_indices[j].tolist()) for j in range(end - start)]
for j in range(end - start):
actual_tid = int(token_ids[start + j + 1])
if actual_tid not in chunk_top_sets[j]:
all_actual_lps.append((chunk_logits[j, actual_tid] - chunk_lse[j]).item())
else:
all_actual_lps.append(None) # will use top_log_probs
all_top_log_probs_list.append(chunk_top_log_probs.cpu())
all_top_indices_list.append(chunk_top_indices.cpu())
unique_ids.update(int(tid) for tid in chunk_top_indices.flatten().tolist())
del chunk_logits, chunk_lse, chunk_top_values
del logits
torch.cuda.empty_cache()
all_top_log_probs = torch.cat(all_top_log_probs_list, dim=0)
all_top_indices = torch.cat(all_top_indices_list, dim=0)
unique_ids_list = sorted(unique_ids)
decoded_list = shared.tokenizer.batch_decode([[tid] for tid in unique_ids_list]) if hasattr(shared.tokenizer, 'batch_decode') else [shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids_list]
decoded_strs = dict(zip(unique_ids_list, decoded_list))
for i in range(1, n_tokens):
token_id = int(token_ids[i])
@@ -139,7 +159,6 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
top_ids = all_top_indices[idx].tolist()
actual_token_str = decoded_strs[token_id]
# Build the top list with the actual prompt token guaranteed at front
if token_id in top_ids:
actual_lp = top_log_probs[top_ids.index(token_id)].item()
alternatives = [
@@ -147,10 +166,10 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
for j in range(k) if top_ids[j] != token_id
]
else:
actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
actual_lp = all_actual_lps[idx]
alternatives = [
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
for j in range(k - 1) # drop lowest to make room
for j in range(k - 1)
]
entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives}

View file

@@ -527,6 +527,20 @@ class Exllamav3Model:
return output
def get_prompt_logits(self, input_ids):
    """Return logits for all positions via a single no-cache forward pass.

    Used by prompt logprobs computation.

    Parameters:
        input_ids: token ids for the full prompt — either a torch.Tensor
            (any shape/device; flattened to one sequence) or a Python
            list/sequence of ints.

    Returns:
        torch.Tensor of shape (1, seq_len, vocab) on CPU in float32.
    """
    import torch
    input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
    # reshape() instead of view(): view() raises RuntimeError on
    # non-contiguous tensors (e.g. a transposed or sliced input);
    # reshape() handles both and is a no-copy view when possible.
    input_ids_tensor = input_ids_tensor.reshape(1, -1).cpu()
    with torch.no_grad():
        # "flash_attn_nc": no-cache attention mode, so the whole prompt is
        # scored in one pass without touching the generation cache.
        return self.model.forward(
            input_ids=input_ids_tensor,
            params={"attn_mode": "flash_attn_nc"}
        ).cpu().float()
def get_logits(self, token_ids, **kwargs):
"""
Process a batch of token_ids and return the logits for the last token.