API: Add token ids to logprobs output

This commit is contained in:
oobabooga 2026-04-02 07:17:27 -07:00
parent a32ce254f2
commit c10c6e87ae

View file

@ -143,17 +143,17 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
if token_id in top_ids:
actual_lp = top_log_probs[top_ids.index(token_id)].item()
alternatives = [
{"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
for j in range(k) if top_ids[j] != token_id
]
else:
actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
alternatives = [
{"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
for j in range(k - 1) # drop lowest to make room
]
entry = {"top_logprobs": [{"token": actual_token_str, "logprob": actual_lp}] + alternatives}
entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives}
entries.append(entry)
return entries
@ -239,7 +239,7 @@ def format_chat_logprobs(entries):
def format_completion_logprobs(entries):
"""Format logprob entries into OpenAI completions logprobs format.
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "top_logprobs_ids": [{token_id: prob}], "text_offset"}
"""
if not entries:
return None
@ -247,6 +247,7 @@ def format_completion_logprobs(entries):
tokens = []
token_logprobs = []
top_logprobs = []
top_logprobs_ids = []
text_offset = []
offset = 0
@ -257,6 +258,7 @@ def format_completion_logprobs(entries):
tokens.append(token_str)
token_logprobs.append(None)
top_logprobs.append(None)
top_logprobs_ids.append(None)
text_offset.append(offset)
offset += len(token_str)
continue
@ -273,21 +275,28 @@ def format_completion_logprobs(entries):
offset += len(token_str)
top_dict = {}
top_dict_ids = {}
for item in top:
t = item.get('token', '')
lp = item.get('logprob', item.get('prob', 0))
top_dict[t] = lp
if 'token_id' in item:
top_dict_ids[item['token_id']] = lp
top_logprobs.append(top_dict)
top_logprobs_ids.append(top_dict_ids if top_dict_ids else None)
if not tokens:
return None
return {
result = {
"tokens": tokens,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"text_offset": text_offset
}
if any(x is not None for x in top_logprobs_ids):
result["top_logprobs_ids"] = top_logprobs_ids
return result
def process_parameters(body, is_legacy=False):