diff --git a/modules/logits.py b/modules/logits.py index 9a4243ff..32aef7ae 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -45,6 +45,9 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur output = {} for entry in logprobs: token = repr(entry['token']) + if len(token) > 2 and token.startswith("'") and token.endswith("'"): + token = token[1:-1] + prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) output[token] = prob return output @@ -52,6 +55,9 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur output = '' for entry in logprobs: token = repr(entry['token']) + if len(token) > 2 and token.startswith("'") and token.endswith("'"): + token = token[1:-1] + prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) output += f"{prob:.5f} - {token}\n" return output, previous