mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-06 15:13:38 +00:00
API: Implement echo + logprobs for /v1/completions endpoint
This commit is contained in:
parent
6382fbef83
commit
71c1a52afe
3 changed files with 309 additions and 55 deletions
|
|
@ -310,8 +310,45 @@ class LlamaServer:
|
|||
else:
|
||||
raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")
|
||||
|
||||
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
    """Get logprob entries for prompt tokens via a single n_predict=0 request.

    Requires a llama.cpp server with prompt_logprobs support.

    Args:
        token_ids: Sequence of prompt token ids (accepts tensors exposing
            ``.tolist()`` as well as plain Python sequences).
        n_probs: Number of top alternative tokens to request per position.
        prompt: Decoded prompt text, used to sanity-check the first token.

    Returns:
        Entries in the standard format for format_completion_logprobs(),
        or ``[]`` when there are no tokens or the server returns no
        prompt probabilities.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    token_ids_list = token_ids.tolist() if hasattr(token_ids, 'tolist') else list(token_ids)
    if not token_ids_list:
        # No tokens to score; skip the server round-trip entirely.
        return []

    url = f"http://127.0.0.1:{self.port}/completion"
    payload = {
        "prompt": token_ids_list,
        "n_predict": 0,  # score the prompt only; generate nothing
        "n_probs": n_probs,
        "prompt_logprobs": True,
        "stream": False,
        "cache_prompt": False,
    }

    response = self.session.post(url, json=payload)
    # Fail loudly on HTTP errors instead of mis-parsing an error body
    # (a non-2xx response would otherwise just look like missing
    # 'prompt_probabilities' and silently return []).
    response.raise_for_status()
    result = response.json()

    prompt_probs = result.get("prompt_probabilities", [])
    if not prompt_probs:
        return []

    # The first token has no conditioning context, so it gets a null
    # logprob entry. Use an empty string for BOS, or for tokens that
    # don't appear at the start of the prompt text.
    first_token_str = self.decode([token_ids_list[0]])
    if self.bos_token and first_token_str == self.bos_token:
        first_token_str = ""
    elif not prompt.startswith(first_token_str):
        first_token_str = ""

    entries = [{"token": first_token_str, "null_logprob": True}]
    entries.extend(prompt_probs)
    return entries
|
||||
|
||||
def _get_vocabulary_size(self):
|
||||
"""Get and store the model's maximum context length."""
|
||||
"""Get and store the model's vocabulary size."""
|
||||
url = f"http://127.0.0.1:{self.port}/v1/models"
|
||||
response = self.session.get(url).json()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue