mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
API: Implement echo + logprobs for /v1/completions endpoint
This commit is contained in:
parent
6382fbef83
commit
71c1a52afe
3 changed files with 309 additions and 55 deletions
|
|
@@ -39,6 +39,129 @@ def load_chat_template_file(filepath):
|
|||
return text
|
||||
|
||||
|
||||
def _first_token_display_str(token_id, prompt, tokenizer):
|
||||
"""Return the display string for the first prompt token.
|
||||
|
||||
Returns empty string for BOS or tokens that don't appear at the start
|
||||
of the prompt text, so they don't shift text_offset for subsequent tokens.
|
||||
"""
|
||||
token_id = int(token_id)
|
||||
bos_id = getattr(tokenizer, 'bos_token_id', None)
|
||||
if bos_id is not None and token_id == bos_id:
|
||||
return ""
|
||||
|
||||
import torch
|
||||
tok = tokenizer.decode(torch.tensor([token_id]))
|
||||
if not prompt.startswith(tok):
|
||||
return ""
|
||||
|
||||
return tok
|
||||
|
||||
|
||||
def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
    """Compute logprob entries for prompt tokens via a forward pass.

    Returns a list of logprob entries in the standard format.
    The first token gets a null entry (no conditioning context).

    Supported for HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
    via a single forward pass, and for llama.cpp via the server's
    prompt_logprobs parameter. Returns [] for unsupported loaders.

    Args:
        prompt: Raw prompt text; used to decide the first token's display
            string (empty for BOS / tokens not at the start of the text).
        logprobs_count: Number of top alternatives requested per position.
        input_ids: Optional pre-tokenized ids for `prompt`. When None the
            prompt is tokenized here. Assumed shape (1, seq_len) — the code
            only indexes [0] and iterates; confirm against encode()'s output.

    Returns:
        List of entry dicts: a leading {"token": ..., "null_logprob": True}
        marker followed by {"top_logprobs": [...]} entries, or [] when the
        active loader has no supported logits path.
    """
    if input_ids is None:
        input_ids = encode(prompt)  # (1, seq_len) tensor or array

    # Work on the single sequence in the batch.
    token_ids = input_ids[0]
    n_tokens = len(token_ids)

    if n_tokens == 0:
        return []

    loader = shared.args.loader
    model = shared.model

    if loader == 'llama.cpp':
        # Delegate to the llama.cpp server helper; max(..., 1) presumably
        # ensures at least one top prob is requested — confirm server-side.
        return model.get_prompt_logprob_entries(token_ids, max(logprobs_count, 1), prompt=prompt)

    first_token_str = _first_token_display_str(token_ids[0], prompt, shared.tokenizer)

    # Single-token prompt: there is nothing to condition on, so only the
    # null entry for the first token can be produced.
    if n_tokens <= 1:
        return [{"token": first_token_str, "null_logprob": True}]

    import torch

    if loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache'):
        # Native ExLlamav3: call the underlying Model.forward() directly
        input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
        with torch.no_grad():
            logits = model.model.forward(
                input_ids=input_ids_tensor,
                params={
                    "attn_mode": "flash_attn",
                    "cache": model.cache,
                    "past_len": 0,
                    "batch_shape": (1, model.max_tokens),
                }
            ).float().cpu()

    elif hasattr(model, 'forward'):
        # HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
        input_ids_tensor = input_ids if isinstance(input_ids, torch.Tensor) else torch.tensor(input_ids, dtype=torch.long)
        if hasattr(model, 'device'):
            input_ids_tensor = input_ids_tensor.to(model.device)
        with torch.no_grad():
            # Pass labels to ensure logits are returned for ALL positions,
            # not just the last token (some HF wrappers like ExLlamav3_HF
            # only compute the last-token logits when labels are absent).
            outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor)
            logits = outputs.logits.float().cpu()

    else:
        # No usable logits path for this loader.
        return []

    entries = [{"token": first_token_str, "null_logprob": True}]

    # Batch logsumexp and topk as single operations across all positions
    # to avoid per-position kernel launch overhead.
    prompt_logits = logits[0, :n_tokens - 1]  # positions 0..n-2 predict tokens 1..n-1
    # Clamp k to the vocabulary size so topk cannot over-request.
    k = min(logprobs_count, prompt_logits.shape[-1])
    all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1)
    all_lse = torch.logsumexp(prompt_logits, dim=-1)
    # log softmax of the top-k values: logit - logsumexp(all logits).
    all_top_log_probs = all_top_values - all_lse.unsqueeze(-1)

    # Batch-decode all unique token IDs to avoid O(N*k) individual decode calls
    unique_ids = set(int(tid) for tid in token_ids[1:])
    unique_ids.update(int(tid) for tid in all_top_indices.flatten().tolist())

    decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids}

    for i in range(1, n_tokens):
        token_id = int(token_ids[i])
        # Position idx produced the logits that predict token i.
        idx = i - 1
        top_log_probs = all_top_log_probs[idx]
        top_ids = all_top_indices[idx].tolist()
        actual_token_str = decoded_strs[token_id]

        # Build the top list with the actual prompt token guaranteed at front
        if token_id in top_ids:
            actual_lp = top_log_probs[top_ids.index(token_id)].item()
            alternatives = [
                {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
                for j in range(k) if top_ids[j] != token_id
            ]
        else:
            # Actual token fell outside the top-k: compute its logprob
            # directly from the full logits row.
            actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
            alternatives = [
                {"token": decoded_strs[top_ids[j]], "logprob": top_log_probs[j].item()}
                for j in range(k - 1)  # drop lowest to make room
            ]

        entry = {"top_logprobs": [{"token": actual_token_str, "logprob": actual_lp}] + alternatives}
        entries.append(entry)

    return entries
|
||||
|
||||
|
||||
def _get_raw_logprob_entries(offset=0):
|
||||
"""Get raw logprob entries from llama.cpp/ExLlamav3 backend, starting from offset.
|
||||
|
||||
|
|
@ -65,6 +188,21 @@ def _parse_entry_top(entry):
|
|||
return entry.get('top_logprobs', entry.get('top_probs', []))
|
||||
|
||||
|
||||
def _extract_sampled_token(entry, top):
|
||||
"""Get the actually sampled token and its logprob from a logprob entry.
|
||||
|
||||
Uses the entry-level token/logprob when available (the actually sampled
|
||||
token), falling back to top[0] (highest-probability alternative) which
|
||||
may differ with non-greedy sampling.
|
||||
"""
|
||||
if 'token' in entry:
|
||||
return entry['token'], entry.get('logprob', entry.get('prob', 0))
|
||||
|
||||
token_str = top[0].get('token', '')
|
||||
token_logprob = top[0].get('logprob', top[0].get('prob', 0))
|
||||
return token_str, token_logprob
|
||||
|
||||
|
||||
def format_chat_logprobs(entries):
|
||||
"""Format logprob entries into OpenAI chat completions logprobs format.
|
||||
|
||||
|
|
@ -79,9 +217,7 @@ def format_chat_logprobs(entries):
|
|||
if not top:
|
||||
continue
|
||||
|
||||
chosen = top[0]
|
||||
token_str = chosen.get('token', '')
|
||||
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
|
||||
token_str, token_logprob = _extract_sampled_token(entry, top)
|
||||
|
||||
top_list = []
|
||||
for item in top:
|
||||
|
|
@ -118,13 +254,21 @@ def format_completion_logprobs(entries):
|
|||
offset = 0
|
||||
|
||||
for entry in entries:
|
||||
# Handle null logprob entries (first prompt token with echo)
|
||||
if entry.get("null_logprob"):
|
||||
token_str = entry.get("token", "")
|
||||
tokens.append(token_str)
|
||||
token_logprobs.append(None)
|
||||
top_logprobs.append(None)
|
||||
text_offset.append(offset)
|
||||
offset += len(token_str)
|
||||
continue
|
||||
|
||||
top = _parse_entry_top(entry)
|
||||
if not top:
|
||||
continue
|
||||
|
||||
chosen = top[0]
|
||||
token_str = chosen.get('token', '')
|
||||
token_logprob = chosen.get('logprob', chosen.get('prob', 0))
|
||||
token_str, token_logprob = _extract_sampled_token(entry, top)
|
||||
|
||||
tokens.append(token_str)
|
||||
token_logprobs.append(token_logprob)
|
||||
|
|
@ -407,7 +551,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
})
|
||||
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens in [None, 0]:
|
||||
if max_tokens is not None and max_tokens <= 0:
|
||||
raise InvalidRequestError(message="max_tokens must be greater than 0.", param="max_tokens")
|
||||
|
||||
if max_tokens is None:
|
||||
generate_params['max_new_tokens'] = 512
|
||||
generate_params['auto_max_new_tokens'] = True
|
||||
|
||||
|
|
@ -652,6 +799,15 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
# common params
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens is None:
|
||||
generate_params['max_new_tokens'] = 512
|
||||
generate_params['auto_max_new_tokens'] = True
|
||||
max_tokens = 512
|
||||
elif max_tokens < 0:
|
||||
raise InvalidRequestError(message="max_tokens must be greater than or equal to 0.", param="max_tokens")
|
||||
elif max_tokens == 0 and body.get('logprobs') is None:
|
||||
raise InvalidRequestError(message="max_tokens is 0 but no logprobs parameter was specified.", param="max_tokens")
|
||||
|
||||
generate_params['stream'] = stream
|
||||
if stop_event is not None:
|
||||
generate_params['stop_event'] = stop_event
|
||||
|
|
@ -700,9 +856,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
prompt = decode(prompt)[0]
|
||||
|
||||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
prompt_input_ids = encode(prompt)
|
||||
token_count = len(prompt_input_ids[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
# Compute prompt logprobs once per prompt (shared across n_completions)
|
||||
logprobs_val = body.get('logprobs', None)
|
||||
if echo and logprobs_val is not None and logprobs_val >= 0:
|
||||
prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids)
|
||||
else:
|
||||
prompt_entries = None
|
||||
|
||||
original_seed = generate_params.get('seed', -1)
|
||||
for _n in range(n_completions):
|
||||
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
|
||||
|
|
@ -713,29 +877,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
logprob_proc.token_alternatives_history.clear()
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
if logprob_proc:
|
||||
all_entries = []
|
||||
for alt in logprob_proc.token_alternatives_history:
|
||||
all_entries.extend(_dict_to_logprob_entries(alt))
|
||||
completion_logprobs = format_completion_logprobs(all_entries)
|
||||
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||
raw = getattr(shared.model, 'last_completion_probabilities', None)
|
||||
completion_logprobs = format_completion_logprobs(raw)
|
||||
if max_tokens == 0:
|
||||
answer = ''
|
||||
completion_token_count = 0
|
||||
stop_reason = "stop"
|
||||
else:
|
||||
completion_logprobs = None
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
total_completion_token_count += completion_token_count
|
||||
|
||||
if max_tokens == 0:
|
||||
all_entries = []
|
||||
else:
|
||||
if logprob_proc:
|
||||
all_entries = []
|
||||
for alt in logprob_proc.token_alternatives_history:
|
||||
all_entries.extend(_dict_to_logprob_entries(alt))
|
||||
elif shared.args.loader in ('llama.cpp', 'ExLlamav3'):
|
||||
all_entries = getattr(shared.model, 'last_completion_probabilities', None) or []
|
||||
else:
|
||||
all_entries = []
|
||||
|
||||
if prompt_entries:
|
||||
all_entries = prompt_entries + all_entries
|
||||
|
||||
completion_logprobs = format_completion_logprobs(all_entries) if all_entries else None
|
||||
|
||||
respi = {
|
||||
"index": choice_index,
|
||||
|
|
@ -775,7 +951,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
prefix = prompt if echo else ''
|
||||
token_count = len(encode(prompt)[0])
|
||||
prompt_input_ids = encode(prompt)
|
||||
token_count = len(prompt_input_ids[0])
|
||||
|
||||
# Check if usage should be included in streaming chunks per OpenAI spec
|
||||
stream_options = body.get('stream_options')
|
||||
|
|
@ -808,37 +985,57 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
return chunk
|
||||
|
||||
logprobs_val = body.get('logprobs', None)
|
||||
if echo and logprobs_val is not None and logprobs_val >= 0:
|
||||
prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids)
|
||||
prompt_logprobs_formatted = format_completion_logprobs(prompt_entries) if prompt_entries else None
|
||||
else:
|
||||
prompt_logprobs_formatted = None
|
||||
|
||||
# Clear stale logprobs from any previous request before building the
|
||||
# first chunk, so text_streaming_chunk doesn't pick up old data.
|
||||
if hasattr(shared.model, 'last_completion_probabilities'):
|
||||
shared.model.last_completion_probabilities = []
|
||||
cmpl_logprobs_offset[0] = 0
|
||||
|
||||
chunk = text_streaming_chunk(prefix)
|
||||
if prompt_logprobs_formatted is not None:
|
||||
chunk[resp_list][0]["logprobs"] = prompt_logprobs_formatted
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
if max_tokens == 0:
|
||||
answer = ''
|
||||
completion_token_count = 0
|
||||
stop_reason = "stop"
|
||||
else:
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
seen_content = answer
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
if include_usage:
|
||||
chunk['usage'] = None
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = text_streaming_chunk(suffix)
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
|
|
|
|||
|
|
@ -489,15 +489,35 @@ class Exllamav3Model:
|
|||
return
|
||||
|
||||
id_to_piece = self.tokenizer.get_id_to_piece_list(True)
|
||||
sampled_ids = result.get("token_ids") # (batch, seq_len) - actually sampled tokens
|
||||
sampled_probs = result.get("token_probs") # (batch, seq_len) - their probabilities
|
||||
|
||||
def _piece(tid):
|
||||
s = id_to_piece[tid] if tid < len(id_to_piece) else f"<{tid}>"
|
||||
return s.replace('\u2581', ' ')
|
||||
|
||||
def _logprob(prob):
|
||||
return math.log(prob) if prob > 0 else float("-inf")
|
||||
|
||||
# top_k_tokens shape: (batch, seq_len, k), top_k_probs same
|
||||
for seq_idx in range(top_k_tokens.shape[1]):
|
||||
entry = {"top_logprobs": []}
|
||||
for k_idx in range(top_k_tokens.shape[2]):
|
||||
token_id = top_k_tokens[0, seq_idx, k_idx].item()
|
||||
prob = top_k_probs[0, seq_idx, k_idx].item()
|
||||
token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>"
|
||||
logprob = math.log(prob) if prob > 0 else float("-inf")
|
||||
entry["top_logprobs"].append({"token": token_str, "logprob": logprob})
|
||||
entry["top_logprobs"].append({"token": _piece(token_id), "logprob": _logprob(prob)})
|
||||
|
||||
# Record the actually sampled token at the entry level so
|
||||
# format_completion_logprobs uses it instead of top_logprobs[0]
|
||||
# (they differ with non-greedy sampling).
|
||||
if sampled_ids is not None:
|
||||
sid = sampled_ids[0, seq_idx].item()
|
||||
entry["token"] = _piece(sid)
|
||||
if sampled_probs is not None:
|
||||
entry["logprob"] = _logprob(sampled_probs[0, seq_idx].item())
|
||||
else:
|
||||
entry["logprob"] = None
|
||||
|
||||
self.last_completion_probabilities.append(entry)
|
||||
|
||||
def generate(self, prompt, state):
|
||||
|
|
|
|||
|
|
@@ -310,8 +310,45 @@ class LlamaServer:
|
|||
else:
|
||||
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
||||
|
||||
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
    """Get logprob entries for prompt tokens via a single n_predict=0 request.

    Requires llama.cpp server with prompt_logprobs support.
    Returns entries in the standard format for format_completion_logprobs().
    """
    if hasattr(token_ids, 'tolist'):
        ids = token_ids.tolist()
    else:
        ids = list(token_ids)

    # Ask the server to evaluate the prompt only (no generation) and
    # report per-token probabilities for it.
    request_body = {
        "prompt": ids,
        "n_predict": 0,
        "n_probs": n_probs,
        "prompt_logprobs": True,
        "stream": False,
        "cache_prompt": False,
    }
    result = self.session.post(
        f"http://127.0.0.1:{self.port}/completion", json=request_body
    ).json()

    prompt_probs = result.get("prompt_probabilities", [])
    if not prompt_probs:
        return []

    # Null first token (no conditioning context); use empty string for BOS
    # or tokens that don't appear at the start of the prompt text.
    first_piece = self.decode([ids[0]])
    is_bos = bool(self.bos_token) and first_piece == self.bos_token
    if is_bos or not prompt.startswith(first_piece):
        first_piece = ""

    return [{"token": first_piece, "null_logprob": True}] + list(prompt_probs)
|
||||
|
||||
def _get_vocabulary_size(self):
|
||||
"""Get and store the model's maximum context length."""
|
||||
"""Get and store the model's vocabulary size."""
|
||||
url = f"http://127.0.0.1:{self.port}/v1/models"
|
||||
response = self.session.get(url).json()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue