diff --git a/modules/api/completions.py b/modules/api/completions.py index a15e1f86..453fa07b 100644 --- a/modules/api/completions.py +++ b/modules/api/completions.py @@ -143,17 +143,17 @@ def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None): if token_id in top_ids: actual_lp = top_log_probs[top_ids.index(token_id)].item() alternatives = [ - {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()} + {"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()} for j in range(k) if top_ids[j] != token_id ] else: actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item() alternatives = [ - {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()} + {"token": decoded_strs[top_ids[j]], "token_id": top_ids[j], "logprob": top_log_probs[j].item()} for j in range(k - 1) # drop lowest to make room ] - entry = {"top_logprobs": [{"token": actual_token_str, "logprob": actual_lp}] + alternatives} + entry = {"top_logprobs": [{"token": actual_token_str, "token_id": token_id, "logprob": actual_lp}] + alternatives} entries.append(entry) return entries @@ -239,7 +239,7 @@ def format_chat_logprobs(entries): def format_completion_logprobs(entries): """Format logprob entries into OpenAI completions logprobs format. - Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"} + Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "top_logprobs_ids": [{token_id: prob}], "text_offset"} """ if not entries: return None @@ -247,6 +247,7 @@ def format_completion_logprobs(entries): tokens = [] token_logprobs = [] top_logprobs = [] + top_logprobs_ids = [] text_offset = [] offset = 0 @@ -257,6 +258,7 @@ def format_completion_logprobs(entries): tokens.append(token_str) token_logprobs.append(None) top_logprobs.append(None) + top_logprobs_ids.append(None) text_offset.append(offset) offset += len(token_str) continue @@ -273,21 +275,28 @@ def format_completion_logprobs(entries): offset += len(token_str) top_dict = {} + top_dict_ids = {} for item in top: t = item.get('token', '') lp = item.get('logprob', item.get('prob', 0)) top_dict[t] = lp + if 'token_id' in item: + top_dict_ids[item['token_id']] = lp top_logprobs.append(top_dict) + top_logprobs_ids.append(top_dict_ids if top_dict_ids else None) if not tokens: return None - return { + result = { "tokens": tokens, "token_logprobs": token_logprobs, "top_logprobs": top_logprobs, "text_offset": text_offset } + if any(x is not None for x in top_logprobs_ids): + result["top_logprobs_ids"] = top_logprobs_ids + return result def process_parameters(body, is_legacy=False):