API: Implement echo + logprobs for /v1/completions endpoint

This commit is contained in:
oobabooga 2026-03-30 20:49:38 -07:00
parent 6382fbef83
commit 71c1a52afe
3 changed files with 309 additions and 55 deletions

View file

@ -39,6 +39,129 @@ def load_chat_template_file(filepath):
return text
def _first_token_display_str(token_id, prompt, tokenizer):
"""Return the display string for the first prompt token.
Returns empty string for BOS or tokens that don't appear at the start
of the prompt text, so they don't shift text_offset for subsequent tokens.
"""
token_id = int(token_id)
bos_id = getattr(tokenizer, 'bos_token_id', None)
if bos_id is not None and token_id == bos_id:
return ""
import torch
tok = tokenizer.decode(torch.tensor([token_id]))
if not prompt.startswith(tok):
return ""
return tok
def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
    """Compute logprob entries for prompt tokens via a forward pass.

    Returns a list of logprob entries in the standard format consumed by
    format_completion_logprobs(). The first token always gets a null entry
    (it has no conditioning context, so no logprob can be assigned).

    Supported for HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
    via a single forward pass, and for llama.cpp via the server's
    prompt_logprobs parameter. Returns [] for unsupported loaders.

    :param prompt: the raw prompt text (used to decide how the first token
        is displayed).
    :param logprobs_count: number of top alternatives requested per token.
    :param input_ids: optional pre-tokenized prompt; re-encoded from
        ``prompt`` when absent.
    """
    if input_ids is None:
        input_ids = encode(prompt)  # (1, seq_len) tensor or array
    token_ids = input_ids[0]
    n_tokens = len(token_ids)
    if n_tokens == 0:
        return []

    loader = shared.args.loader
    model = shared.model

    # llama.cpp scores the prompt server-side in a single request.
    if loader == 'llama.cpp':
        return model.get_prompt_logprob_entries(token_ids, max(logprobs_count, 1), prompt=prompt)

    first_token_str = _first_token_display_str(token_ids[0], prompt, shared.tokenizer)
    if n_tokens <= 1:
        # Only the context-free first token exists; nothing to score.
        return [{"token": first_token_str, "null_logprob": True}]

    import torch
    if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
        # Native ExLlamav3: call the underlying Model.forward() directly
        input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
        with torch.no_grad():
            logits = model.model.forward(
                input_ids=input_ids_tensor,
                params={
                    "attn_mode": "flash_attn",
                    "cache": model.cache,
                    "past_len": 0,
                    "batch_shape": (1, model.max_tokens),
                }
            ).float().cpu()
    elif hasattr(model, 'forward'):
        # HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
        input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
        if hasattr(model, 'device'):
            input_ids_tensor = input_ids_tensor.to(model.device)
        with torch.no_grad():
            # Pass labels to ensure logits are returned for ALL positions,
            # not just the last token (some HF wrappers like ExLlamav3_HF
            # only compute the last-token logits when labels are absent).
            outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor)
            logits = outputs.logits.float().cpu()
    else:
        # Loader exposes no forward pass we know how to drive.
        return []

    entries = [{"token": first_token_str, "null_logprob": True}]

    # Batch logsumexp and topk as single operations across all positions
    # to avoid per-position kernel launch overhead.
    prompt_logits = logits[0, :n_tokens - 1]  # positions 0..n-2 predict tokens 1..n-1
    k = min(logprobs_count, prompt_logits.shape[-1])
    all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1)
    all_lse = torch.logsumexp(prompt_logits, dim=-1)
    all_top_log_probs = all_top_values - all_lse.unsqueeze(-1)

    # Batch-decode all unique token IDs to avoid O(N*k) individual decode calls
    unique_ids = set(int(tid) for tid in token_ids[1:])
    unique_ids.update(int(tid) for tid in all_top_indices.flatten().tolist())
    decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids}

    for i in range(1, n_tokens):
        token_id = int(token_ids[i])
        idx = i - 1  # row of logits that predicted token i
        top_log_probs = all_top_log_probs[idx]
        top_ids = all_top_indices[idx].tolist()
        actual_token_str = decoded_strs[token_id]

        # Build the top list with the actual prompt token guaranteed at front
        if token_id in top_ids:
            actual_lp = top_log_probs[top_ids.index(token_id)].item()
            alternatives = [
                {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
                for j in range(k) if top_ids[j] != token_id
            ]
        else:
            # Actual token fell outside the top-k: score it directly
            # from the raw logits and the precomputed normalizer.
            actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
            alternatives = [
                {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
                for j in range(k - 1)  # drop lowest to make room
            ]
        entry = {"top_logprobs": [{"token": actual_token_str, "logprob": actual_lp}] + alternatives}
        entries.append(entry)

    return entries
def _get_raw_logprob_entries(offset=0):
"""Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
@ -65,6 +188,21 @@ def _parse_entry_top(entry):
return entry.get('top_logprobs', entry.get('top_probs', []))
def _extract_sampled_token(entry, top):
"""Get the actually sampled token and its logprob from a logprob entry.
Uses the entry-level token/logprob when available (the actually sampled
token), falling back to top[0] (highest-probability alternative) which
may differ with non-greedy sampling.
"""
if 'token' in entry:
return entry['token'], entry.get('logprob', entry.get('prob', 0))
token_str = top[0].get('token', '')
token_logprob = top[0].get('logprob', top[0].get('prob', 0))
return token_str, token_logprob
def format_chat_logprobs(entries):
"""Format logprob entries into OpenAI chat completions logprobs format.
@ -79,9 +217,7 @@ def format_chat_logprobs(entries):
if not top:
continue
chosen = top[0]
token_str = chosen.get('token', '')
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
token_str, token_logprob = _extract_sampled_token(entry, top)
top_list = []
for item in top:
@ -118,13 +254,21 @@ def format_completion_logprobs(entries):
offset = 0
for entry in entries:
# Handle null logprob entries (first prompt token with echo)
if entry.get("null_logprob"):
token_str = entry.get("token", "")
tokens.append(token_str)
token_logprobs.append(None)
top_logprobs.append(None)
text_offset.append(offset)
offset += len(token_str)
continue
top = _parse_entry_top(entry)
if not top:
continue
chosen = top[0]
token_str = chosen.get('token', '')
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
token_str, token_logprob = _extract_sampled_token(entry, top)
tokens.append(token_str)
token_logprobs.append(token_logprob)
@ -407,7 +551,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
})
max_tokens = generate_params['max_new_tokens']
if max_tokens in [None, 0]:
if max_tokens is not None and max_tokens <= 0:
raise InvalidRequestError(message="max_tokens must be greater than 0.", param="max_tokens")
if max_tokens is None:
generate_params['max_new_tokens'] = 512
generate_params['auto_max_new_tokens'] = True
@ -652,6 +799,15 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
# common params
generate_params = process_parameters(body, is_legacy=is_legacy)
max_tokens = generate_params['max_new_tokens']
if max_tokens is None:
generate_params['max_new_tokens'] = 512
generate_params['auto_max_new_tokens'] = True
max_tokens = 512
elif max_tokens < 0:
raise InvalidRequestError(message="max_tokens must be greater than or equal to 0.", param="max_tokens")
elif max_tokens == 0 and body.get('logprobs') is None:
raise InvalidRequestError(message="max_tokens is 0 but no logprobs parameter was specified.", param="max_tokens")
generate_params['stream'] = stream
if stop_event is not None:
generate_params['stop_event'] = stop_event
@ -700,9 +856,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prompt = decode(prompt)[0]
prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
prompt_input_ids = encode(prompt)
token_count = len(prompt_input_ids[0])
total_prompt_token_count += token_count
# Compute prompt logprobs once per prompt (shared across n_completions)
logprobs_val = body.get('logprobs', None)
if echo and logprobs_val is not None and logprobs_val >= 0:
prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids)
else:
prompt_entries = None
original_seed = generate_params.get('seed', -1)
for _n in range(n_completions):
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
@ -713,29 +877,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
logprob_proc.token_alternatives_history.clear()
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
completion_logprobs = format_completion_logprobs(all_entries)
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
raw = getattr(shared.model, 'last_completion_probabilities', None)
completion_logprobs = format_completion_logprobs(raw)
if max_tokens == 0:
answer = ''
completion_token_count = 0
stop_reason = "stop"
else:
completion_logprobs = None
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
total_completion_token_count += completion_token_count
if max_tokens == 0:
all_entries = []
else:
if logprob_proc:
all_entries = []
for alt in logprob_proc.token_alternatives_history:
all_entries.extend(_dict_to_logprob_entries(alt))
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
all_entries = getattr(shared.model, 'last_completion_probabilities', None) or []
else:
all_entries = []
if prompt_entries:
all_entries = prompt_entries + all_entries
completion_logprobs = format_completion_logprobs(all_entries) if all_entries else None
respi = {
"index": choice_index,
@ -775,7 +951,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
prompt_input_ids = encode(prompt)
token_count = len(prompt_input_ids[0])
# Check if usage should be included in streaming chunks per OpenAI spec
stream_options = body.get('stream_options')
@ -808,37 +985,57 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
return chunk
logprobs_val = body.get('logprobs', None)
if echo and logprobs_val is not None and logprobs_val >= 0:
prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids)
prompt_logprobs_formatted = format_completion_logprobs(prompt_entries) if prompt_entries else None
else:
prompt_logprobs_formatted = None
# Clear stale logprobs from any previous request before building the
# first chunk, so text_streaming_chunk doesn't pick up old data.
if hasattr(shared.model, 'last_completion_probabilities'):
shared.model.last_completion_probabilities = []
cmpl_logprobs_offset[0] = 0
chunk = text_streaming_chunk(prefix)
if prompt_logprobs_formatted is not None:
chunk[resp_list][0]["logprobs"] = prompt_logprobs_formatted
if include_usage:
chunk['usage'] = None
yield chunk
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
if max_tokens == 0:
answer = ''
completion_token_count = 0
stop_reason = "stop"
else:
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
for a in generator:
answer = a
len_seen = len(seen_content)
new_content = answer[len_seen:]
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
chunk = text_streaming_chunk(new_content)
if include_usage:
chunk['usage'] = None
yield chunk
seen_content = answer
chunk = text_streaming_chunk(new_content)
if include_usage:
chunk['usage'] = None
yield chunk
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
chunk = text_streaming_chunk(suffix)
chunk[resp_list][0]["finish_reason"] = stop_reason

View file

@ -489,15 +489,35 @@ class Exllamav3Model:
return
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
sampled_ids = result.get("token_ids") # (batch, seq_len) - actually sampled tokens
sampled_probs = result.get("token_probs") # (batch, seq_len) - their probabilities
def _piece(tid):
s = id_to_piece[tid] if tid < len(id_to_piece) else f"<{tid}>"
return s.replace('\u2581', ' ')
def _logprob(prob):
return math.log(prob) if prob > 0 else float("-inf")
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
for seq_idx in range(top_k_tokens.shape[1]):
entry = {"top_logprobs": []}
for k_idx in range(top_k_tokens.shape[2]):
token_id = top_k_tokens[0, seq_idx, k_idx].item()
prob = top_k_probs[0, seq_idx, k_idx].item()
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
logprob = math.log(prob) if prob > 0 else float("-inf")
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
entry["top_logprobs"].append({"token": _piece(token_id), "logprob": _logprob(prob)})
# Record the actually sampled token at the entry level so
# format_completion_logprobs uses it instead of top_logprobs[0]
# (they differ with non-greedy sampling).
if sampled_ids is not None:
sid = sampled_ids[0, seq_idx].item()
entry["token"] = _piece(sid)
if sampled_probs is not None:
entry["logprob"] = _logprob(sampled_probs[0, seq_idx].item())
else:
entry["logprob"] = None
self.last_completion_probabilities.append(entry)
def generate(self, prompt, state):

View file

@ -310,8 +310,45 @@ class LlamaServer:
else:
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
    """Fetch per-token prompt logprobs via a single n_predict=0 request.

    Requires a llama.cpp server that supports the prompt_logprobs flag.
    Returns entries in the standard format for format_completion_logprobs();
    the first token becomes a null entry whose display string is "" when it
    is the BOS token or does not literally start the prompt text.
    """
    if hasattr(token_ids, 'tolist'):
        ids = token_ids.tolist()
    else:
        ids = list(token_ids)

    payload = {
        "prompt": ids,
        "n_predict": 0,
        "n_probs": n_probs,
        "prompt_logprobs": True,
        "stream": False,
        "cache_prompt": False,
    }
    endpoint = f"http://127.0.0.1:{self.port}/completion"
    result = self.session.post(endpoint, json=payload).json()

    prompt_probs = result.get("prompt_probabilities", [])
    if not prompt_probs:
        return []

    # Null first token (no conditioning context); render it as "" for BOS
    # or when its decoded text does not start the prompt.
    head = self.decode([ids[0]])
    if (self.bos_token and head == self.bos_token) or not prompt.startswith(head):
        head = ""

    return [{"token": head, "null_logprob": True}] + list(prompt_probs)
def _get_vocabulary_size(self):
"""Get and store the model's maximum context length."""
"""Get and store the model's vocabulary size."""
url = f"http://127.0.0.1:{self.port}/v1/models"
response = self.session.get(url).json()