From ea1f8c71f2e92dc9ae230b943c605e43ff5c633c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:30:59 -0300 Subject: [PATCH] API: Optimize prompt logprobs and refactor ExLlamav3 forward pass --- modules/api/completions.py | 69 ++++++++++++++++++++++++-------------- modules/exllamav3.py | 14 ++++++++ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/modules/api/completions.py b/modules/api/completions.py index 453fa07b..4eb8fdad 100644 --- a/modules/api/completions.py +++ b/modules/api/completions.py @@ -90,16 +90,8 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None): import torch - if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'): - # Native ExLlamav3: call the underlying Model.forward() in chunks - # to avoid OOM from giant logits tensors (seq_len * vocab_size * 4 bytes) - input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long) - input_ids_tensor = input_ids_tensor.view(-1).cpu() - with torch.no_grad(): - logits = model.model.forward( - input_ids=input_ids_tensor.view(1, -1), - params={"attn_mode": "flash_attn_nc"} - ).float().cpu() + if hasattr(model, 'get_prompt_logits'): + logits = model.get_prompt_logits(input_ids) elif hasattr(model, 'forward'): # HF-compatible loaders (Transformers, ExLlamav3_HF, etc.) @@ -111,26 +103,54 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None): # not just the last token (some HF wrappers like ExLlamav3_HF # only compute the last-token logits when labels are absent). outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor) - logits = outputs.logits.float().cpu() + logits = outputs.logits # keep on GPU, (1, seq_len, vocab) in model dtype + del outputs else: return [] entries = [{"token": first_token_str, "null_logprob": True}] - # Batch logsumexp and topk as single operations across all positions - # to avoid per-position kernel launch overhead. - prompt_logits = logits[0, :n_tokens - 1] # positions 0..n-2 predict tokens 1..n-1 - k = min(logprobs_count, prompt_logits.shape[-1]) - all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1) - all_lse = torch.logsumexp(prompt_logits, dim=-1) - all_top_log_probs = all_top_values - all_lse.unsqueeze(-1) - - # Batch-decode all unique token IDs to avoid O(N*k) individual decode calls + logprobs_count = max(logprobs_count, 1) + k = min(logprobs_count, logits.shape[-1]) + chunk_size = 2048 unique_ids = set(int(tid) for tid in token_ids[1:]) - unique_ids.update(int(tid) for tid in all_top_indices.flatten().tolist()) - decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids} + # Process logits in chunks on GPU, only move top-K results to CPU + all_top_log_probs_list = [] + all_top_indices_list = [] + all_actual_lps = [] + + for start in range(0, n_tokens - 1, chunk_size): + end = min(start + chunk_size, n_tokens - 1) + chunk_logits = logits[0, start:end].float() # (chunk, vocab) on GPU + chunk_lse = torch.logsumexp(chunk_logits, dim=-1) + chunk_top_values, chunk_top_indices = torch.topk(chunk_logits, k=k, dim=-1) + chunk_top_log_probs = chunk_top_values - chunk_lse.unsqueeze(-1) + + # Compute logprob for actual next tokens in this chunk + chunk_top_sets = [set(chunk_top_indices[j].tolist()) for j in range(end - start)] + for j in range(end - start): + actual_tid = int(token_ids[start + j + 1]) + if actual_tid not in chunk_top_sets[j]: + all_actual_lps.append((chunk_logits[j, actual_tid] - chunk_lse[j]).item()) + else: + all_actual_lps.append(None) # will use top_log_probs + + all_top_log_probs_list.append(chunk_top_log_probs.cpu()) + all_top_indices_list.append(chunk_top_indices.cpu()) + unique_ids.update(int(tid) for tid in chunk_top_indices.flatten().tolist()) + del chunk_logits, chunk_lse, chunk_top_values + + del logits + torch.cuda.empty_cache() + + all_top_log_probs = torch.cat(all_top_log_probs_list, dim=0) + all_top_indices = torch.cat(all_top_indices_list, dim=0) + + unique_ids_list = sorted(unique_ids) + decoded_list = shared.tokenizer.batch_decode([[tid] for tid in unique_ids_list]) if hasattr(shared.tokenizer, 'batch_decode') else [shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids_list] + decoded_strs = dict(zip(unique_ids_list, decoded_list)) for i in range(1, n_tokens): token_id = int(token_ids[i]) @@ -139,7 +159,6 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None): top_ids = all_top_indices[idx].tolist() actual_token_str = decoded_strs[token_id] - # Build the top list with the actual prompt token guaranteed at front if token_id in top_ids: actual_lp = top_log_probs[top_ids.index(token_id)].item() alternatives = [ @@ -147,10 +166,10 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None): for j in range(k) if top_ids[j] != token_id ] else: - actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item() + actual_lp = all_actual_lps[idx] alternatives = [ {"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()} - for j in range(k - 1) # drop lowest to make room + for j in range(k - 1) ] entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives} diff --git a/modules/exllamav3.py b/modules/exllamav3.py index 7556a908..e1efbfeb 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -527,6 +527,20 @@ class Exllamav3Model: return output + def get_prompt_logits(self, input_ids): + """Return logits for all positions via a single no-cache forward pass. + + Used by prompt logprobs computation. Returns (1, seq_len, vocab) on CPU in float32. + """ + import torch + input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long) + input_ids_tensor = input_ids_tensor.view(1, -1).cpu() + with torch.no_grad(): + return self.model.forward( + input_ids=input_ids_tensor, + params={"attn_mode": "flash_attn_nc"} + ).cpu().float() + def get_logits(self, token_ids, **kwargs): """ Process a batch of token_ids and return the logits for the last token.