API: Rewrite logprobs for OpenAI spec compliance across all backends

- Rewrite logprobs output format to match the OpenAI specification for
  both chat completions and completions endpoints (target shapes are
  sketched below)
- Fix the top_logprobs count being ignored for the llama.cpp and ExLlamav3
  backends in chat completions (always returned 1 instead of the requested N)
- Fix non-streaming responses only returning logprobs for the last token
  instead of all generated tokens (affects all HF-based loaders)
- Fix logprobs returning null for non-streaming chat requests on HF loaders
- Fix an off-by-one that returned one extra top alternative on HF loaders
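
For reference, a minimal sketch of the two target shapes. The field names come
from the new formatters in this diff; the tokens and numbers are invented:

# Chat completions: choices[0].logprobs, as built by format_chat_logprobs()
{
    "content": [
        {
            "token": "Hello",
            "logprob": -0.31,
            "bytes": [72, 101, 108, 108, 111],
            "top_logprobs": [
                {"token": "Hello", "logprob": -0.31, "bytes": [72, 101, 108, 108, 111]},
                {"token": "Hi", "logprob": -1.42, "bytes": [72, 105]}
            ]
        }
    ]
}

# Completions (legacy): choices[0].logprobs, as built by format_completion_logprobs()
{
    "tokens": ["Hello", "!"],
    "token_logprobs": [-0.31, -0.05],
    "top_logprobs": [{"Hello": -0.31, "Hi": -1.42}, {"!": -0.05, ".": -2.97}],
    "text_offset": [0, 5]
}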
oobabooga 2026-03-12 14:16:34 -03:00
parent 5a017aa338
commit fb1b3b6ddf
3 changed files with 149 additions and 43 deletions


@@ -37,35 +37,114 @@ def load_chat_template_file(filepath):
     return text


-def get_logprobs_from_backend():
-    """Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
+def _get_raw_logprob_entries(offset=0):
+    """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
+
+    Returns (new_entries, new_offset).
+    """
+    if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
+        return [], offset
+
+    all_entries = shared.model.last_completion_probabilities
+    new_entries = all_entries[offset:]
+    return new_entries, len(all_entries)
+
+
+def _dict_to_logprob_entries(token_dict):
+    """Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
+    if not token_dict:
+        return []
+
+    return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
+
+
+def _parse_entry_top(entry):
+    """Extract the top logprobs list from a raw entry, handling both key names."""
+    return entry.get('top_logprobs', entry.get('top_probs', []))
+
+
+def format_chat_logprobs(entries):
+    """Format logprob entries into OpenAI chat completions logprobs format.
+
+    Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
+    """
+    if not entries:
+        return None

-    # Both backends store data in shared.model.last_completion_probabilities
-    # Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
-    result = {}
-    for entry in shared.model.last_completion_probabilities:
-        top = entry.get('top_logprobs', entry.get('top_probs', []))
+    content = []
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        top_list = []
         for item in top:
-            token = item.get('token', '')
-            logprob = item.get('logprob', item.get('prob', 0))
-            result[token] = logprob
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_list.append({
+                "token": t,
+                "logprob": lp,
+                "bytes": list(t.encode('utf-8')) if t else None
+            })

-    return result
+        content.append({
+            "token": token_str,
+            "logprob": token_logprob,
+            "bytes": list(token_str.encode('utf-8')) if token_str else None,
+            "top_logprobs": top_list
+        })
+
+    return {"content": content} if content else None


-def convert_logprobs_to_tiktoken(model, logprobs):
-    # more problems than it's worth.
-    # try:
-    #     encoder = tiktoken.encoding_for_model(model)
-    #     # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
-    #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
-    # except KeyError:
-    #     # assume native tokens if we can't find the tokenizer
-    #     return logprobs
+def format_completion_logprobs(entries):
+    """Format logprob entries into OpenAI completions logprobs format.

-    return logprobs
+    Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
+    """
+    if not entries:
+        return None
+
+    tokens = []
+    token_logprobs = []
+    top_logprobs = []
+    text_offset = []
+    offset = 0
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        tokens.append(token_str)
+        token_logprobs.append(token_logprob)
+        text_offset.append(offset)
+        offset += len(token_str)
+
+        top_dict = {}
+        for item in top:
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_dict[t] = lp
+
+        top_logprobs.append(top_dict)
+
+    if not tokens:
+        return None
+
+    return {
+        "tokens": tokens,
+        "token_logprobs": token_logprobs,
+        "top_logprobs": top_logprobs,
+        "text_offset": text_offset
+    }


 def process_parameters(body, is_legacy=False):
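
A quick illustration of the completion formatter above, on a hypothetical
two-entry input in the backend's raw format (tokens and numbers invented):

entries = [
    {"top_logprobs": [{"token": "The", "logprob": -0.12},
                      {"token": "A", "logprob": -2.35}]},
    {"top_logprobs": [{"token": " cat", "logprob": -0.78},
                      {"token": " dog", "logprob": -1.10}]},
]

format_completion_logprobs(entries)
# => {"tokens": ["The", " cat"],
#     "token_logprobs": [-0.12, -0.78],
#     "top_logprobs": [{"The": -0.12, "A": -2.35}, {" cat": -0.78, " dog": -1.10}],
#     "text_offset": [0, 3]}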
@@ -90,6 +169,14 @@ def process_parameters(body, is_legacy=False):
     elif isinstance(body['stop'], list):
         generate_params['custom_stopping_strings'] = body['stop']

+    # Resolve logprobs: for chat completions, logprobs is a bool and the count
+    # comes from top_logprobs. Normalize to an int for all backends.
+    logprobs = body.get('logprobs', None)
+    top_logprobs = body.get('top_logprobs', None)
+    if logprobs is True:
+        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
+    generate_params['logprobs'] = logprobs
+
     # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
     if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
         from transformers import LogitsProcessorList
@@ -104,11 +191,6 @@ def process_parameters(body, is_legacy=False):
         if logit_bias: # {str: float, ...}
             logits_processor = [LogitsBiasProcessor(logit_bias)]

-        logprobs = body.get('logprobs', None)
-        top_logprobs = body.get('top_logprobs', None)
-        # For chat completions, logprobs is a bool; use top_logprobs for the count
-        if logprobs is True:
-            logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
         if logprobs is not None and logprobs > 0:
             generate_params['logprob_proc'] = LogprobProcessor(logprobs)
             logits_processor.extend([generate_params['logprob_proc']])
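
Net effect of the move: the bool-plus-count resolution now happens once, up
front, for every backend. A sketch of the resolved values (example request
bodies, not literal code from the diff):

{"logprobs": true, "top_logprobs": 3}   ->  generate_params['logprobs'] = 3
{"logprobs": true}                      ->  generate_params['logprobs'] = 5   (default)
{"logprobs": 4}                         ->  generate_params['logprobs'] = 4   (legacy completions)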
@@ -317,6 +399,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()
+    chat_logprobs_offset = [0]  # mutable for closure access in streaming

     def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False):
         # begin streaming
@@ -344,12 +429,16 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         }

         if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
+            formatted = format_chat_logprobs(entries)
+            if formatted:
+                chunk[resp_list][0]["logprobs"] = formatted
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            if backend_logprobs:
-                chunk[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
+            if entries:
+                formatted = format_chat_logprobs(entries)
+                if formatted:
+                    chunk[resp_list][0]["logprobs"] = formatted

         return chunk
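
The offset is what keeps streaming chunks from repeating logprobs: each chunk
consumes only the entries the backend appended since the previous chunk.
Roughly (a sketch, not the literal call sites):

offset = 0
entries, offset = _get_raw_logprob_entries(offset)   # first chunk: everything so far
# ... backend generates two more tokens ...
entries, offset = _get_raw_logprob_entries(offset)   # next chunk: only the two new entries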
@@ -471,12 +560,18 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         }
     }

     if logprob_proc:
-        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-        resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+        all_entries = []
+        for alt in logprob_proc.token_alternatives_history:
+            all_entries.extend(_dict_to_logprob_entries(alt))
+        formatted = format_chat_logprobs(all_entries)
+        if formatted:
+            resp[resp_list][0]["logprobs"] = formatted
     elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-        backend_logprobs = get_logprobs_from_backend()
-        if backend_logprobs:
-            resp[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+        raw = getattr(shared.model, 'last_completion_probabilities', None)
+        if raw:
+            formatted = format_chat_logprobs(raw)
+            if formatted:
+                resp[resp_list][0]["logprobs"] = formatted

     yield resp
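
This branch is the fix for non-streaming chat responses that carried only the
last token: token_alternatives holds just the most recent step, while the new
token_alternatives_history accumulates every step. A sketch with invented tokens:

logprob_proc.token_alternatives           # {"!": -0.05, ".": -2.97}   (last step only)
logprob_proc.token_alternatives_history   # [{"Hi": -0.21, ...}, {"!": -0.05, ...}]   (all steps)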
@@ -518,6 +613,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     generate_params['stop_event'] = stop_event
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()

     suffix = body['suffix'] if body['suffix'] else ''
     echo = body['echo']
@@ -583,10 +680,13 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
             stop_reason = "length"

         if logprob_proc:
-            completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            all_entries = []
+            for alt in logprob_proc.token_alternatives_history:
+                all_entries.extend(_dict_to_logprob_entries(alt))
+            completion_logprobs = format_completion_logprobs(all_entries)
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+            raw = getattr(shared.model, 'last_completion_probabilities', None)
+            completion_logprobs = format_completion_logprobs(raw)
         else:
             completion_logprobs = None
@@ -633,14 +733,15 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     # Check if usage should be included in streaming chunks per OpenAI spec
     stream_options = body.get('stream_options')
     include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
+    cmpl_logprobs_offset = [0]  # mutable for closure access in streaming

     def text_streaming_chunk(content):
         # begin streaming
         if logprob_proc:
-            chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+            entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
+            chunk_logprobs = format_completion_logprobs(entries) if entries else None
         else:
             chunk_logprobs = None


@@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
+        original_logits_processor = state.get('logits_processor')
         stop_event_ref = state.pop('stop_event', None)
         state = copy.deepcopy(state)
         if stop_event_ref is not None:
             state['stop_event'] = stop_event_ref
+        if original_logits_processor is not None:
+            state['logits_processor'] = original_logits_processor
         state['stream'] = True

     # Generate
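
Context for the change above: the stop-string path deep-copies state, which
would also clone an attached LogprobProcessor, so generation would append
history to the clone while the API layer still reads the original object.
Re-attaching the original keeps both sides pointing at the same processor.
A simplified sketch of the failure mode (Proc is a stand-in, not the real class):

import copy

class Proc:
    def __init__(self):
        self.history = []

state = {'logits_processor': Proc()}
clone = copy.deepcopy(state)
clone['logits_processor'].history.append('token')

state['logits_processor'].history  # => [] -- the original never saw the update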


@@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
     def __init__(self, logprobs=None):
         self.logprobs = logprobs
         self.token_alternatives = {}
+        self.token_alternatives_history = []

     def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
         if self.logprobs is not None: # 0-5
             log_e_probabilities = F.log_softmax(logits, dim=1)
-            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
+            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
             top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
             top_probs = [float(x) for x in top_values[0]]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
+            self.token_alternatives_history.append(self.token_alternatives)

         return logits
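
The k change above is the off-by-one fix from the commit message: topk with
k=self.logprobs + 1 returned one extra alternative per step. A minimal check
(shapes only, values random):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 32000)  # (batch, vocab)
logprobs = 5                    # requested top_logprobs count

log_probs = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_probs, k=logprobs)
top_values.shape  # => torch.Size([1, 5]) -- exactly N alternatives, not N + 1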