Fix requires_grad warning in logits API

This commit is contained in:
oobabooga 2026-03-04 10:43:23 -08:00
parent 64eb77e782
commit 5d93f4e800

View file

@ -106,7 +106,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
output = shared.model(input_ids=tokens)
scores = output['logits'][-1][-1]
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float)
topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]