mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
API: Optimize prompt logprobs and refactor ExLlamav3 forward pass
This commit is contained in:
parent
c10c6e87ae
commit
ea1f8c71f2
2 changed files with 58 additions and 25 deletions
|
|
@ -90,16 +90,8 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
|
||||
import torch
|
||||
|
||||
if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
|
||||
# Native ExLlamav3: call the underlying Model.forward() in chunks
|
||||
# to avoid OOM from giant logits tensors (seq_len * vocab_size * 4 bytes)
|
||||
input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
|
||||
input_ids_tensor = input_ids_tensor.view(-1).cpu()
|
||||
with torch.no_grad():
|
||||
logits = model.model.forward(
|
||||
input_ids=input_ids_tensor.view(1, -1),
|
||||
params={"attn_mode": "flash_attn_nc"}
|
||||
).float().cpu()
|
||||
if hasattr(model, 'get_prompt_logits'):
|
||||
logits = model.get_prompt_logits(input_ids)
|
||||
|
||||
elif hasattr(model, 'forward'):
|
||||
# HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
|
||||
|
|
@ -111,26 +103,54 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
# not just the last token (some HF wrappers like ExLlamav3_HF
|
||||
# only compute the last-token logits when labels are absent).
|
||||
outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor)
|
||||
logits = outputs.logits.float().cpu()
|
||||
logits = outputs.logits # keep on GPU, (1, seq_len, vocab) in model dtype
|
||||
del outputs
|
||||
|
||||
else:
|
||||
return []
|
||||
|
||||
entries = [{"token": first_token_str, "null_logprob": True}]
|
||||
|
||||
# Batch logsumexp and topk as single operations across all positions
|
||||
# to avoid per-position kernel launch overhead.
|
||||
prompt_logits = logits[0, :n_tokens - 1] # positions 0..n-2 predict tokens 1..n-1
|
||||
k = min(logprobs_count, prompt_logits.shape[-1])
|
||||
all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1)
|
||||
all_lse = torch.logsumexp(prompt_logits, dim=-1)
|
||||
all_top_log_probs = all_top_values - all_lse.unsqueeze(-1)
|
||||
|
||||
# Batch-decode all unique token IDs to avoid O(N*k) individual decode calls
|
||||
logprobs_count = max(logprobs_count, 1)
|
||||
k = min(logprobs_count, logits.shape[-1])
|
||||
chunk_size = 2048
|
||||
unique_ids = set(int(tid) for tid in token_ids[1:])
|
||||
unique_ids.update(int(tid) for tid in all_top_indices.flatten().tolist())
|
||||
|
||||
decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids}
|
||||
# Process logits in chunks on GPU, only move top-K results to CPU
|
||||
all_top_log_probs_list = []
|
||||
all_top_indices_list = []
|
||||
all_actual_lps = []
|
||||
|
||||
for start in range(0, n_tokens - 1, chunk_size):
|
||||
end = min(start + chunk_size, n_tokens - 1)
|
||||
chunk_logits = logits[0, start:end].float() # (chunk, vocab) on GPU
|
||||
chunk_lse = torch.logsumexp(chunk_logits, dim=-1)
|
||||
chunk_top_values, chunk_top_indices = torch.topk(chunk_logits, k=k, dim=-1)
|
||||
chunk_top_log_probs = chunk_top_values - chunk_lse.unsqueeze(-1)
|
||||
|
||||
# Compute logprob for actual next tokens in this chunk
|
||||
chunk_top_sets = [set(chunk_top_indices[j].tolist()) for j in range(end - start)]
|
||||
for j in range(end - start):
|
||||
actual_tid = int(token_ids[start + j + 1])
|
||||
if actual_tid not in chunk_top_sets[j]:
|
||||
all_actual_lps.append((chunk_logits[j, actual_tid] - chunk_lse[j]).item())
|
||||
else:
|
||||
all_actual_lps.append(None) # will use top_log_probs
|
||||
|
||||
all_top_log_probs_list.append(chunk_top_log_probs.cpu())
|
||||
all_top_indices_list.append(chunk_top_indices.cpu())
|
||||
unique_ids.update(int(tid) for tid in chunk_top_indices.flatten().tolist())
|
||||
del chunk_logits, chunk_lse, chunk_top_values
|
||||
|
||||
del logits
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
all_top_log_probs = torch.cat(all_top_log_probs_list, dim=0)
|
||||
all_top_indices = torch.cat(all_top_indices_list, dim=0)
|
||||
|
||||
unique_ids_list = sorted(unique_ids)
|
||||
decoded_list = shared.tokenizer.batch_decode([[tid] for tid in unique_ids_list]) if hasattr(shared.tokenizer, 'batch_decode') else [shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids_list]
|
||||
decoded_strs = dict(zip(unique_ids_list, decoded_list))
|
||||
|
||||
for i in range(1, n_tokens):
|
||||
token_id = int(token_ids[i])
|
||||
|
|
@ -139,7 +159,6 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
top_ids = all_top_indices[idx].tolist()
|
||||
actual_token_str = decoded_strs[token_id]
|
||||
|
||||
# Build the top list with the actual prompt token guaranteed at front
|
||||
if token_id in top_ids:
|
||||
actual_lp = top_log_probs[top_ids.index(token_id)].item()
|
||||
alternatives = [
|
||||
|
|
@ -147,10 +166,10 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
for j in range(k) if top_ids[j] != token_id
|
||||
]
|
||||
else:
|
||||
actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
|
||||
actual_lp = all_actual_lps[idx]
|
||||
alternatives = [
|
||||
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
|
||||
for j in range(k - 1) # drop lowest to make room
|
||||
for j in range(k - 1)
|
||||
]
|
||||
|
||||
entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue