API: Improve cache clearing in logprobs

This commit is contained in:
oobabooga 2026-04-02 17:50:42 -07:00
parent d84157403a
commit 7aab2fdf9a

View file

@@ -89,6 +89,7 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
return [{"token": first_token_str, "null_logprob": True}]
import torch
from modules.torch_utils import clear_torch_cache
if hasattr(model, 'get_prompt_logits'):
logits = model.get_prompt_logits(input_ids)
@@ -143,7 +144,7 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
del chunk_logits, chunk_lse, chunk_top_values
del logits
torch.cuda.empty_cache()
clear_torch_cache()
all_top_log_probs = torch.cat(all_top_log_probs_list, dim=0)
all_top_indices = torch.cat(all_top_indices_list, dim=0)