mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
API: Add token ids to logprobs output
This commit is contained in:
parent
a32ce254f2
commit
c10c6e87ae
1 changed file with 14 additions and 5 deletions
|
|
@ -143,17 +143,17 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
|
|||
if token_id in top_ids:
|
||||
actual_lp = top_log_probs[top_ids.index(token_id)].item()
|
||||
alternatives = [
|
||||
{"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
|
||||
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
|
||||
for j in range(k) if top_ids[j] != token_id
|
||||
]
|
||||
else:
|
||||
actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
|
||||
alternatives = [
|
||||
{"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
|
||||
{"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()}
|
||||
for j in range(k - 1) # drop lowest to make room
|
||||
]
|
||||
|
||||
entry = {"top_logprobs": [{"token": actual_token_str, "logprob": actual_lp}] + alternatives}
|
||||
entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives}
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
|
@ -239,7 +239,7 @@ def format_chat_logprobs(entries):
|
|||
def format_completion_logprobs(entries):
|
||||
"""Format logprob entries into OpenAI completions logprobs format.
|
||||
|
||||
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
|
||||
Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "top_logprobs_ids": [{token_id: prob}], "text_offset"}
|
||||
"""
|
||||
if not entries:
|
||||
return None
|
||||
|
|
@ -247,6 +247,7 @@ def format_completion_logprobs(entries):
|
|||
tokens = []
|
||||
token_logprobs = []
|
||||
top_logprobs = []
|
||||
top_logprobs_ids = []
|
||||
text_offset = []
|
||||
offset = 0
|
||||
|
||||
|
|
@ -257,6 +258,7 @@ def format_completion_logprobs(entries):
|
|||
tokens.append(token_str)
|
||||
token_logprobs.append(None)
|
||||
top_logprobs.append(None)
|
||||
top_logprobs_ids.append(None)
|
||||
text_offset.append(offset)
|
||||
offset += len(token_str)
|
||||
continue
|
||||
|
|
@ -273,21 +275,28 @@ def format_completion_logprobs(entries):
|
|||
offset += len(token_str)
|
||||
|
||||
top_dict = {}
|
||||
top_dict_ids = {}
|
||||
for item in top:
|
||||
t = item.get('token', '')
|
||||
lp = item.get('logprob', item.get('prob', 0))
|
||||
top_dict[t] = lp
|
||||
if 'token_id' in item:
|
||||
top_dict_ids[item['token_id']] = lp
|
||||
top_logprobs.append(top_dict)
|
||||
top_logprobs_ids.append(top_dict_ids if top_dict_ids else None)
|
||||
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
return {
|
||||
result = {
|
||||
"tokens": tokens,
|
||||
"token_logprobs": token_logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
"text_offset": text_offset
|
||||
}
|
||||
if any(x is not None for x in top_logprobs_ids):
|
||||
result["top_logprobs_ids"] = top_logprobs_ids
|
||||
return result
|
||||
|
||||
|
||||
def process_parameters(body, is_legacy=False):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue