mirror of https://github.com/oobabooga/text-generation-webui.git
API: Rewrite logprobs for OpenAI spec compliance across all backends
- Rewrite logprobs output format to match the OpenAI specification for both chat completions and completions endpoints
- Fix top_logprobs count being ignored for llama.cpp and ExLlamav3 backends in chat completions (always returned 1 instead of the requested N)
- Fix non-streaming responses only returning logprobs for the last token instead of all generated tokens (affects all HF-based loaders)
- Fix logprobs returning null for non-streaming chat requests on HF loaders
- Fix off-by-one returning one extra top alternative on HF loaders
This commit is contained in:
parent 5a017aa338
commit fb1b3b6ddf
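For reference, the OpenAI logprobs shapes the new formatters target, sketched by hand from the spec with illustrative values (not output captured from this commit):

# Chat completions: choices[0].logprobs
chat_logprobs = {
    "content": [
        {
            "token": "Hello",
            "logprob": -0.31,
            "bytes": [72, 101, 108, 108, 111],
            "top_logprobs": [
                {"token": "Hello", "logprob": -0.31, "bytes": [72, 101, 108, 108, 111]},
                {"token": "Hi", "logprob": -1.42, "bytes": [72, 105]},
            ],
        },
    ],
}

# Legacy completions: choices[0].logprobs
completion_logprobs = {
    "tokens": ["Hello", "!"],
    "token_logprobs": [-0.31, -0.05],
    "top_logprobs": [{"Hello": -0.31, "Hi": -1.42}, {"!": -0.05, ".": -3.1}],
    "text_offset": [0, 5],
}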
@@ -37,35 +37,114 @@ def load_chat_template_file(filepath):
     return text
 
 
-def get_logprobs_from_backend():
-    """Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
-    if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
-        return None
-
-    # Both backends store data in shared.model.last_completion_probabilities
-    # Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
-    result = {}
-    for entry in shared.model.last_completion_probabilities:
-        top = entry.get('top_logprobs', entry.get('top_probs', []))
-        for item in top:
-            token = item.get('token', '')
-            logprob = item.get('logprob', item.get('prob', 0))
-            result[token] = logprob
-
-    return result
+def _get_raw_logprob_entries(offset=0):
+    """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
+
+    Returns (new_entries, new_offset).
+    """
+    if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
+        return [], offset
+
+    all_entries = shared.model.last_completion_probabilities
+    new_entries = all_entries[offset:]
+    return new_entries, len(all_entries)
+
+
+def _dict_to_logprob_entries(token_dict):
+    """Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
+    if not token_dict:
+        return []
+
+    return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
+
+
+def _parse_entry_top(entry):
+    """Extract the top logprobs list from a raw entry, handling both key names."""
+    return entry.get('top_logprobs', entry.get('top_probs', []))
+
+
+def format_chat_logprobs(entries):
+    """Format logprob entries into OpenAI chat completions logprobs format.
+
+    Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
+    """
+    if not entries:
+        return None
+
+    content = []
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        top_list = []
+        for item in top:
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_list.append({
+                "token": t,
+                "logprob": lp,
+                "bytes": list(t.encode('utf-8')) if t else None
+            })
+
+        content.append({
+            "token": token_str,
+            "logprob": token_logprob,
+            "bytes": list(token_str.encode('utf-8')) if token_str else None,
+            "top_logprobs": top_list
+        })
+
+    return {"content": content} if content else None
+
+
-def convert_logprobs_to_tiktoken(model, logprobs):
-    # more problems than it's worth.
-    # try:
-    #     encoder = tiktoken.encoding_for_model(model)
-    #     # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
-    #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
-    # except KeyError:
-    #     # assume native tokens if we can't find the tokenizer
-    #     return logprobs
-
-    return logprobs
+def format_completion_logprobs(entries):
+    """Format logprob entries into OpenAI completions logprobs format.
+
+    Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
+    """
+    if not entries:
+        return None
+
+    tokens = []
+    token_logprobs = []
+    top_logprobs = []
+    text_offset = []
+    offset = 0
+
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        tokens.append(token_str)
+        token_logprobs.append(token_logprob)
+        text_offset.append(offset)
+        offset += len(token_str)
+
+        top_dict = {}
+        for item in top:
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_dict[t] = lp
+        top_logprobs.append(top_dict)
+
+    if not tokens:
+        return None
+
+    return {
+        "tokens": tokens,
+        "token_logprobs": token_logprobs,
+        "top_logprobs": top_logprobs,
+        "text_offset": text_offset
+    }
 
 
 def process_parameters(body, is_legacy=False):
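A hedged usage sketch of what format_chat_logprobs does with one raw entry. chat_format below is a condensed restatement written for this note (it only handles the 'top_logprobs' key name), not an import of the real helper:

def chat_format(entries):
    content = []
    for entry in entries:
        top = entry.get("top_logprobs", [])
        if not top:
            continue
        alts = [
            {"token": i["token"], "logprob": i["logprob"],
             "bytes": list(i["token"].encode("utf-8")) or None}
            for i in top
        ]
        # First alternative is the sampled token; attach the full list to it
        content.append({**alts[0], "top_logprobs": alts})
    return {"content": content} if content else None

entries = [{"top_logprobs": [{"token": "Hi", "logprob": -0.2},
                             {"token": "Hey", "logprob": -1.8}]}]
print(chat_format(entries))
# -> {'content': [{'token': 'Hi', 'logprob': -0.2, 'bytes': [72, 105], 'top_logprobs': [...]}]}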
@@ -90,6 +169,14 @@ def process_parameters(body, is_legacy=False):
         elif isinstance(body['stop'], list):
             generate_params['custom_stopping_strings'] = body['stop']
 
+    # Resolve logprobs: for chat completions, logprobs is a bool and the count
+    # comes from top_logprobs. Normalize to an int for all backends.
+    logprobs = body.get('logprobs', None)
+    top_logprobs = body.get('top_logprobs', None)
+    if logprobs is True:
+        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
+    generate_params['logprobs'] = logprobs
+
     # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
     if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
         from transformers import LogitsProcessorList
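The normalization rule added above, restated as a standalone function so its behavior is easy to check (the function name is illustrative, not part of the commit):

def resolve_logprobs(logprobs, top_logprobs):
    # Chat completions send logprobs as a bool with the count in top_logprobs;
    # legacy completions send an int directly. Normalize to an int (or None).
    if logprobs is True:
        return top_logprobs if top_logprobs and top_logprobs > 0 else 5
    return logprobs

assert resolve_logprobs(True, 3) == 3        # chat: requested count honored
assert resolve_logprobs(True, None) == 5     # chat: default when count omitted
assert resolve_logprobs(4, None) == 4        # legacy completions: int passthrough
assert resolve_logprobs(None, None) is None  # logprobs not requested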
@@ -104,11 +191,6 @@ def process_parameters(body, is_legacy=False):
         if logit_bias:  # {str: float, ...}
             logits_processor = [LogitsBiasProcessor(logit_bias)]
 
-        logprobs = body.get('logprobs', None)
-        top_logprobs = body.get('top_logprobs', None)
-        # For chat completions, logprobs is a bool; use top_logprobs for the count
-        if logprobs is True:
-            logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
         if logprobs is not None and logprobs > 0:
             generate_params['logprob_proc'] = LogprobProcessor(logprobs)
             logits_processor.extend([generate_params['logprob_proc']])
@@ -317,6 +399,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
 
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()
+    chat_logprobs_offset = [0]  # mutable for closure access in streaming
 
     def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False):
         # begin streaming
@@ -344,12 +429,16 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         }
 
         if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
+            formatted = format_chat_logprobs(entries)
+            if formatted:
+                chunk[resp_list][0]["logprobs"] = formatted
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            if backend_logprobs:
-                chunk[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
+            if entries:
+                formatted = format_chat_logprobs(entries)
+                if formatted:
+                    chunk[resp_list][0]["logprobs"] = formatted
 
         return chunk
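Why the streaming path carries an offset: the backend's probability list grows as tokens are generated, and each chunk should emit only the entries added since the previous chunk. A toy version of the cursor logic, with a plain list standing in for shared.model.last_completion_probabilities:

def read_new(all_entries, offset):
    # Mirrors _get_raw_logprob_entries: return unseen entries plus the new cursor
    return all_entries[offset:], len(all_entries)

backend = []
cursor = 0

backend.extend(["entry1", "entry2"])         # two tokens generated so far
new, cursor = read_new(backend, cursor)      # -> ["entry1", "entry2"], cursor = 2

backend.append("entry3")                     # one more token
new, cursor = read_new(backend, cursor)      # -> ["entry3"], cursor = 3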
@@ -471,12 +560,18 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             }
         }
         if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            all_entries = []
+            for alt in logprob_proc.token_alternatives_history:
+                all_entries.extend(_dict_to_logprob_entries(alt))
+            formatted = format_chat_logprobs(all_entries)
+            if formatted:
+                resp[resp_list][0]["logprobs"] = formatted
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            if backend_logprobs:
-                resp[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+            raw = getattr(shared.model, 'last_completion_probabilities', None)
+            if raw:
+                formatted = format_chat_logprobs(raw)
+                if formatted:
+                    resp[resp_list][0]["logprobs"] = formatted
 
         yield resp
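This hunk is the "all tokens, not just the last" fix for HF loaders: logprob_proc.token_alternatives holds only the most recent token's alternatives, while token_alternatives_history accumulates one dict per generated token. A toy illustration with hypothetical data:

history = [
    {"Hi": -0.2, "Hey": -1.8},   # alternatives recorded at token 1
    {"!": -0.1, ".": -2.9},      # alternatives recorded at token 2
]

# One raw entry per generated token, in the shape format_chat_logprobs expects
entries = [
    {"top_logprobs": [{"token": t, "logprob": lp} for t, lp in d.items()]}
    for d in history
]
assert len(entries) == 2  # one entry per token, not just the final one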
@@ -518,6 +613,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     generate_params['stop_event'] = stop_event
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()
     suffix = body['suffix'] if body['suffix'] else ''
     echo = body['echo']
@@ -583,10 +680,13 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
                 stop_reason = "length"
 
             if logprob_proc:
-                completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+                all_entries = []
+                for alt in logprob_proc.token_alternatives_history:
+                    all_entries.extend(_dict_to_logprob_entries(alt))
+                completion_logprobs = format_completion_logprobs(all_entries)
             elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-                backend_logprobs = get_logprobs_from_backend()
-                completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+                raw = getattr(shared.model, 'last_completion_probabilities', None)
+                completion_logprobs = format_completion_logprobs(raw)
             else:
                 completion_logprobs = None
@@ -633,14 +733,15 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     # Check if usage should be included in streaming chunks per OpenAI spec
     stream_options = body.get('stream_options')
     include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
+    cmpl_logprobs_offset = [0]  # mutable for closure access in streaming
 
     def text_streaming_chunk(content):
         # begin streaming
         if logprob_proc:
-            chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+            entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
+            chunk_logprobs = format_completion_logprobs(entries) if entries else None
         else:
             chunk_logprobs = None
@@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
+        original_logits_processor = state.get('logits_processor')
         stop_event_ref = state.pop('stop_event', None)
        state = copy.deepcopy(state)
         if stop_event_ref is not None:
             state['stop_event'] = stop_event_ref
+        if original_logits_processor is not None:
+            state['logits_processor'] = original_logits_processor
         state['stream'] = True
 
     # Generate
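The re-attach above matters because copy.deepcopy clones everything in state, including any LogprobProcessor held under 'logits_processor', so alternatives would accumulate on the clone while the API layer keeps reading the original. A toy demonstration of that failure mode (class and field names are illustrative):

import copy

class Proc:
    def __init__(self):
        self.history = []

proc = Proc()                                  # reference the caller keeps
state = {'logits_processor': [proc]}

cloned = copy.deepcopy(state)                  # what the old code effectively used
cloned['logits_processor'][0].history.append('token')

print(proc.history)                            # [] -- the original never updated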
@@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
     def __init__(self, logprobs=None):
         self.logprobs = logprobs
         self.token_alternatives = {}
+        self.token_alternatives_history = []
 
     def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
         if self.logprobs is not None:  # 0-5
             log_e_probabilities = F.log_softmax(logits, dim=1)
-            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
+            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
             top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
             top_probs = [float(x) for x in top_values[0]]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
+            self.token_alternatives_history.append(self.token_alternatives)
 
         return logits
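The one-line topk change is the off-by-one fix from the commit message: k=self.logprobs + 1 returned N + 1 alternatives when the client asked for N. A quick check of the corrected behavior:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
requested = 2                                   # client asked for 2 alternatives

values, indices = torch.topk(F.log_softmax(logits, dim=1), k=requested)
assert values.shape[1] == requested             # exactly N alternatives, not N + 1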