mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-06 05:33:50 +01:00
Fix requires_grad warning in logits API
This commit is contained in:
parent
64eb77e782
commit
5d93f4e800
|
|
@ -106,7 +106,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
|||
output = shared.model(input_ids=tokens)
|
||||
scores = output['logits'][-1][-1]
|
||||
|
||||
probs = torch.softmax(scores, dim=-1, dtype=torch.float)
|
||||
probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float)
|
||||
topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
|
||||
if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
|
||||
tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
|
||||
|
|
|
|||
Loading…
Reference in a new issue