diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index a8b899d5..10cdbf42 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -37,35 +37,114 @@ def load_chat_template_file(filepath):
         return text
 
 
-def get_logprobs_from_backend():
-    """Read logprobs captured from llama.cpp or ExLlamav3 native backend."""
+def _get_raw_logprob_entries(offset=0):
+    """Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
+
+    Returns (new_entries, new_offset).
+    """
     if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
+        return [], offset
+
+    all_entries = shared.model.last_completion_probabilities
+    new_entries = all_entries[offset:]
+    return new_entries, len(all_entries)
+
+
+def _dict_to_logprob_entries(token_dict):
+    """Convert a flat {token: logprob} dict (from LogprobProcessor) to raw entry format."""
+    if not token_dict:
+        return []
+
+    return [{"top_logprobs": [{"token": t, "logprob": lp} for t, lp in token_dict.items()]}]
+
+
+def _parse_entry_top(entry):
+    """Extract the top logprobs list from a raw entry, handling both key names."""
+    return entry.get('top_logprobs', entry.get('top_probs', []))
+
+
+def format_chat_logprobs(entries):
+    """Format logprob entries into OpenAI chat completions logprobs format.
+
+    Output: {"content": [{"token", "logprob", "bytes", "top_logprobs": [...]}]}
+    """
+    if not entries:
         return None
 
-    # Both backends store data in shared.model.last_completion_probabilities
-    # Format: [{"top_logprobs": [{"token": "text", "logprob": -0.5}, ...]}, ...]
-    result = {}
-    for entry in shared.model.last_completion_probabilities:
-        top = entry.get('top_logprobs', entry.get('top_probs', []))
+    content = []
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        top_list = []
         for item in top:
-            token = item.get('token', '')
-            logprob = item.get('logprob', item.get('prob', 0))
-            result[token] = logprob
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_list.append({
+                "token": t,
+                "logprob": lp,
+                "bytes": list(t.encode('utf-8')) if t else None
+            })
 
-    return result
+        content.append({
+            "token": token_str,
+            "logprob": token_logprob,
+            "bytes": list(token_str.encode('utf-8')) if token_str else None,
+            "top_logprobs": top_list
+        })
+
+    return {"content": content} if content else None
 
 
-def convert_logprobs_to_tiktoken(model, logprobs):
-    # more problems than it's worth.
-    # try:
-    #     encoder = tiktoken.encoding_for_model(model)
-    #     # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
-    #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
-    # except KeyError:
-    #     # assume native tokens if we can't find the tokenizer
-    #     return logprobs
+def format_completion_logprobs(entries):
+    """Format logprob entries into OpenAI completions logprobs format.
 
-    return logprobs
+    Output: {"tokens", "token_logprobs", "top_logprobs": [{token: prob}], "text_offset"}
+    """
+    if not entries:
+        return None
+
+    tokens = []
+    token_logprobs = []
+    top_logprobs = []
+    text_offset = []
+    offset = 0
+
+    for entry in entries:
+        top = _parse_entry_top(entry)
+        if not top:
+            continue
+
+        chosen = top[0]
+        token_str = chosen.get('token', '')
+        token_logprob = chosen.get('logprob', chosen.get('prob', 0))
+
+        tokens.append(token_str)
+        token_logprobs.append(token_logprob)
+        text_offset.append(offset)
+        offset += len(token_str)
+
+        top_dict = {}
+        for item in top:
+            t = item.get('token', '')
+            lp = item.get('logprob', item.get('prob', 0))
+            top_dict[t] = lp
+        top_logprobs.append(top_dict)
+
+    if not tokens:
+        return None
+
+    return {
+        "tokens": tokens,
+        "token_logprobs": token_logprobs,
+        "top_logprobs": top_logprobs,
+        "text_offset": text_offset
+    }
 
 
 def process_parameters(body, is_legacy=False):
@@ -90,6 +169,14 @@ def process_parameters(body, is_legacy=False):
         elif isinstance(body['stop'], list):
             generate_params['custom_stopping_strings'] = body['stop']
 
+    # Resolve logprobs: for chat completions, logprobs is a bool and the count
+    # comes from top_logprobs. Normalize to an int for all backends.
+    logprobs = body.get('logprobs', None)
+    top_logprobs = body.get('top_logprobs', None)
+    if logprobs is True:
+        logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
+    generate_params['logprobs'] = logprobs
+
     # For llama.cpp and ExLlamav3 native, logit_bias and logprobs are forwarded natively
     if shared.args.loader not in ('llama.cpp', 'ExLlamav3'):
         from transformers import LogitsProcessorList
@@ -104,11 +191,6 @@ def process_parameters(body, is_legacy=False):
         if logit_bias:  # {str: float, ...}
             logits_processor = [LogitsBiasProcessor(logit_bias)]
 
-        logprobs = body.get('logprobs', None)
-        top_logprobs = body.get('top_logprobs', None)
-        # For chat completions, logprobs is a bool; use top_logprobs for the count
-        if logprobs is True:
-            logprobs = top_logprobs if top_logprobs and top_logprobs > 0 else 5
         if logprobs is not None and logprobs > 0:
             generate_params['logprob_proc'] = LogprobProcessor(logprobs)
             logits_processor.extend([generate_params['logprob_proc']])
@@ -317,6 +399,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()
+    chat_logprobs_offset = [0]  # mutable for closure access in streaming
 
     def chat_streaming_chunk(content=None, chunk_tool_calls=None, include_role=False):
         # begin streaming
@@ -344,12 +429,16 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         }
 
         if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            entries = _dict_to_logprob_entries(logprob_proc.token_alternatives)
+            formatted = format_chat_logprobs(entries)
+            if formatted:
+                chunk[resp_list][0]["logprobs"] = formatted
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            if backend_logprobs:
-                chunk[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+            entries, chat_logprobs_offset[0] = _get_raw_logprob_entries(chat_logprobs_offset[0])
+            if entries:
+                formatted = format_chat_logprobs(entries)
+                if formatted:
+                    chunk[resp_list][0]["logprobs"] = formatted
 
         return chunk
@@ -471,12 +560,18 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             }
         }
 
        if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
+            all_entries = []
+            for alt in logprob_proc.token_alternatives_history:
+                all_entries.extend(_dict_to_logprob_entries(alt))
+            formatted = format_chat_logprobs(all_entries)
+            if formatted:
+                resp[resp_list][0]["logprobs"] = formatted
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            if backend_logprobs:
-                resp[resp_list][0]["logprobs"] = {'top_logprobs': [backend_logprobs]}
+            raw = getattr(shared.model, 'last_completion_probabilities', None)
+            if raw:
+                formatted = format_chat_logprobs(raw)
+                if formatted:
+                    resp[resp_list][0]["logprobs"] = formatted
 
         yield resp
@@ -518,6 +613,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
         generate_params['stop_event'] = stop_event
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
+    if logprob_proc:
+        logprob_proc.token_alternatives_history.clear()
     suffix = body['suffix'] if body['suffix'] else ''
     echo = body['echo']
@@ -583,10 +680,13 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
             stop_reason = "length"
 
         if logprob_proc:
-            completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            all_entries = []
+            for alt in logprob_proc.token_alternatives_history:
+                all_entries.extend(_dict_to_logprob_entries(alt))
+            completion_logprobs = format_completion_logprobs(all_entries)
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            completion_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+            raw = getattr(shared.model, 'last_completion_probabilities', None)
+            completion_logprobs = format_completion_logprobs(raw)
         else:
             completion_logprobs = None
@@ -633,14 +733,15 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
     # Check if usage should be included in streaming chunks per OpenAI spec
     stream_options = body.get('stream_options')
    include_usage = bool(stream_options) and bool(stream_options.get('include_usage') if isinstance(stream_options, dict) else getattr(stream_options, 'include_usage', False))
+    cmpl_logprobs_offset = [0]  # mutable for closure access in streaming
 
     def text_streaming_chunk(content):
         # begin streaming
         if logprob_proc:
-            chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            chunk_logprobs = format_completion_logprobs(_dict_to_logprob_entries(logprob_proc.token_alternatives))
         elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
-            backend_logprobs = get_logprobs_from_backend()
-            chunk_logprobs = {'top_logprobs': [backend_logprobs]} if backend_logprobs else None
+            entries, cmpl_logprobs_offset[0] = _get_raw_logprob_entries(cmpl_logprobs_offset[0])
+            chunk_logprobs = format_completion_logprobs(entries) if entries else None
         else:
             chunk_logprobs = None
diff --git a/modules/text_generation.py b/modules/text_generation.py
index c78afe3e..787c1814 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -78,10 +78,13 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
+        original_logits_processor = state.get('logits_processor')
         stop_event_ref = state.pop('stop_event', None)
         state = copy.deepcopy(state)
         if stop_event_ref is not None:
             state['stop_event'] = stop_event_ref
+        if original_logits_processor is not None:
+            state['logits_processor'] = original_logits_processor
         state['stream'] = True
 
     # Generate
diff --git a/modules/transformers_loader.py b/modules/transformers_loader.py
index d57020c6..b9918764 100644
--- a/modules/transformers_loader.py
+++ b/modules/transformers_loader.py
@@ -65,14 +65,16 @@ class LogprobProcessor(LogitsProcessor):
     def __init__(self, logprobs=None):
         self.logprobs = logprobs
         self.token_alternatives = {}
+        self.token_alternatives_history = []
 
     def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
         if self.logprobs is not None:  # 0-5
             log_e_probabilities = F.log_softmax(logits, dim=1)
-            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
+            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs)
             top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
             top_probs = [float(x) for x in top_values[0]]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
+            self.token_alternatives_history.append(self.token_alternatives)
         return logits
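
A few illustrative notes for reviewers follow; none of the code below is part of the patch.

The chat API expresses logprobs as a boolean plus a separate top_logprobs count, while the legacy completions API sends an integer, so process_parameters() now normalizes both forms to a single integer before any backend sees it. A minimal standalone sketch of that rule (5 is the fallback count when top_logprobs is absent):

def resolve_logprobs(logprobs, top_logprobs=None):
    # Chat style: logprobs=True plus top_logprobs=N. Legacy style: logprobs=N.
    if logprobs is True:
        return top_logprobs if top_logprobs and top_logprobs > 0 else 5
    return logprobs

assert resolve_logprobs(True, 3) == 3      # chat request with an explicit count
assert resolve_logprobs(True) == 5         # chat request without top_logprobs
assert resolve_logprobs(2) == 2            # legacy request passes through
assert resolve_logprobs(None) is None      # logprobs not requested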
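
format_chat_logprobs() emits the OpenAI chat-completions shape: a "content" list with one item per generated token, each carrying the token, its logprob, its UTF-8 bytes, and the list of top alternatives. The following standalone sketch re-implements the formatting on made-up entries rather than importing the module, just to show the resulting shape for two tokens:

sample_entries = [
    {"top_logprobs": [{"token": "Hello", "logprob": -0.12},
                      {"token": "Hi", "logprob": -2.31}]},
    {"top_logprobs": [{"token": "!", "logprob": -0.05},
                      {"token": ".", "logprob": -3.40}]},
]

content = []
for entry in sample_entries:
    top = entry["top_logprobs"]
    chosen = top[0]  # the first candidate is treated as the sampled token
    content.append({
        "token": chosen["token"],
        "logprob": chosen["logprob"],
        "bytes": list(chosen["token"].encode("utf-8")),
        "top_logprobs": [
            {"token": t["token"], "logprob": t["logprob"],
             "bytes": list(t["token"].encode("utf-8"))}
            for t in top
        ],
    })

print({"content": content})  # shape assigned to resp[...][0]["logprobs"] in the chat path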
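
format_completion_logprobs() targets the older text-completions layout instead: four parallel lists, where text_offset records each token's starting character position, accumulated from the first formatted entry. A minimal sketch of that bookkeeping with invented tokens:

sample_entries = [
    {"top_logprobs": [{"token": "The", "logprob": -0.2}, {"token": "A", "logprob": -1.9}]},
    {"top_logprobs": [{"token": " cat", "logprob": -0.4}, {"token": " dog", "logprob": -1.1}]},
]

tokens, token_logprobs, top_logprobs, text_offset = [], [], [], []
offset = 0
for entry in sample_entries:
    top = entry["top_logprobs"]
    chosen = top[0]
    tokens.append(chosen["token"])
    token_logprobs.append(chosen["logprob"])
    text_offset.append(offset)        # character offset where this token starts
    offset += len(chosen["token"])    # advance by the token's character length
    top_logprobs.append({t["token"]: t["logprob"] for t in top})

print({"tokens": tokens, "token_logprobs": token_logprobs,
       "top_logprobs": top_logprobs, "text_offset": text_offset})
# text_offset comes out as [0, 3] because "The" is three characters long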
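
The offsets carried in chat_logprobs_offset / cmpl_logprobs_offset exist because shared.model.last_completion_probabilities keeps growing for the whole generation: each streamed chunk should only report the entries added since the previous chunk, while the final non-streaming response reads the full list. A toy sketch of the same slicing, with a plain list standing in for the backend attribute:

def new_entries_since(history, offset):
    # Mirrors _get_raw_logprob_entries(): return only unseen entries plus the new offset.
    return history[offset:], len(history)

history = []          # stands in for shared.model.last_completion_probabilities
offset = 0

history.extend(["entry for token 1", "entry for token 2"])
chunk, offset = new_entries_since(history, offset)   # first chunk: both entries, offset becomes 2

history.append("entry for token 3")
chunk, offset = new_entries_since(history, offset)   # next chunk: only the new entry
print(chunk, offset)                                 # ['entry for token 3'] 3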