Logits display improvements (#7486)

This commit is contained in:
wiger3 2026-04-15 03:29:19 +02:00 committed by GitHub
parent fbd95bd5e6
commit 75bf2feb59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 31 additions and 18 deletions

View file

@ -296,19 +296,26 @@ class LlamaServer:
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
print()
for retry in range(5):
response = self.session.post(url, json=payload)
result = response.json()
def _try_fetch_logits():
for retry in range(5):
response = self.session.post(url, json=payload)
result = response.json()
if "completion_probabilities" in result:
if use_samplers:
return result["completion_probabilities"][0]["top_probs"]
else:
return result["completion_probabilities"][0]["top_logprobs"]
if "completion_probabilities" in result:
if use_samplers:
return result["completion_probabilities"][0]["top_probs"]
else:
return result["completion_probabilities"][0]["top_logprobs"]
time.sleep(0.05)
else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
time.sleep(0.05)
else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
result = _try_fetch_logits()
for entry in result:
if not entry.get('token'):
entry['token'] = self.decode([entry['id']])
return result
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
"""Get logprob entries for prompt tokens via a single n_predict=0 request.