From 75bf2feb590f2af2785d1deff92f11891c624f67 Mon Sep 17 00:00:00 2001 From: wiger3 <74871505+wiger3@users.noreply.github.com> Date: Wed, 15 Apr 2026 03:29:19 +0200 Subject: [PATCH] Logits display improvements (#7486) --- modules/llama_cpp_server.py | 29 ++++++++++++++++++----------- modules/logits.py | 20 +++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index c01f5d5b..f77d2e07 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -296,19 +296,26 @@ class LlamaServer: pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload) print() - for retry in range(5): - response = self.session.post(url, json=payload) - result = response.json() + def _try_fetch_logits(): + for retry in range(5): + response = self.session.post(url, json=payload) + result = response.json() - if "completion_probabilities" in result: - if use_samplers: - return result["completion_probabilities"][0]["top_probs"] - else: - return result["completion_probabilities"][0]["top_logprobs"] + if "completion_probabilities" in result: + if use_samplers: + return result["completion_probabilities"][0]["top_probs"] + else: + return result["completion_probabilities"][0]["top_logprobs"] - time.sleep(0.05) - else: - raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") + time.sleep(0.05) + else: + raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") + + result = _try_fetch_logits() + for entry in result: + if not entry.get('token'): + entry['token'] = self.decode([entry['id']]) + return result def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""): """Get logprob entries for prompt tokens via a single n_predict=0 request. diff --git a/modules/logits.py b/modules/logits.py index 473f5890..09849721 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -36,15 +36,22 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur return error_message, previous # llama.cpp case + def _escaped(token): + chars = [] + for a in token: + # C0 and DEL and C1 + if ord(a) <= 0x1F or 0x7F <= ord(a) <= 0x9F: + chars.append(repr(a)[1:-1]) + else: + chars.append(a) + return ''.join(chars) if shared.model.__class__.__name__ == 'LlamaServer': logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers) if return_dict: output = {} for entry in logprobs: - token = repr(entry['token']) - if len(token) > 2 and token.startswith("'") and token.endswith("'"): - token = token[1:-1] + token = _escaped(entry['token']) prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) output[token] = prob @@ -52,12 +59,11 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur else: output = '' for entry in logprobs: - token = repr(entry['token']) - if len(token) > 2 and token.startswith("'") and token.endswith("'"): - token = token[1:-1] + token = _escaped(entry['token']) + token_id = entry['id'] prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) - output += f"{prob:.5f} - {token}\n" + output += f"{prob:.5f} - [{token}] ({token_id})\n" return output, previous # All other model types