diff --git a/modules/logits.py b/modules/logits.py index 3bcdd07f..7a5f98ae 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -106,7 +106,7 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur output = shared.model(input_ids=tokens) scores = output['logits'][-1][-1] - probs = torch.softmax(scores, dim=-1, dtype=torch.float) + probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float) topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True) if hasattr(shared.tokenizer, 'convert_ids_to_tokens'): tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]