Forward logit_bias, logprobs, and n to llama.cpp backend

- Forward logit_bias and logprobs natively to llama.cpp
- Support n>1 completions with seed increment for diversity
- Fix logprobs returning empty dict when not requested
oobabooga 2026-03-10 10:41:15 -03:00
parent 6ec4ca8b10
commit 8aeaa76365
4 changed files with 94 additions and 30 deletions
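
Below is a minimal client sketch (not part of the commit) that exercises the three forwarded parameters in a single non-streaming request. It assumes the OpenAI-compatible API is reachable at http://127.0.0.1:5000 (adjust host/port to your setup); the prompt and the biased token id 15043 are placeholders.

import requests

# Hypothetical request; host/port, prompt, and the biased token id are placeholders.
payload = {
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "logit_bias": {"15043": -100},  # token id (as string) -> bias, forwarded to llama.cpp
    "logprobs": 5,                  # top-5 alternatives per generated token
    "n": 2,                         # two completions; the seed is incremented per completion
    "stream": False,                # n > 1 is rejected when streaming
}

resp = requests.post("http://127.0.0.1:5000/v1/completions", json=payload, timeout=120)
resp.raise_for_status()
for choice in resp.json()["choices"]:
    print(choice["index"], choice["finish_reason"], choice["text"])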

View file

@@ -37,6 +37,24 @@ def load_chat_template_file(filepath):
return text
def get_logprobs_from_llama_cpp():
"""Read logprobs captured from the llama.cpp server response."""
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
return None
# Convert llama.cpp format to {token: logprob} dict
# llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
result = {}
for entry in shared.model.last_completion_probabilities:
top = entry.get('top_logprobs', entry.get('top_probs', []))
for item in top:
token = item.get('token', '')
logprob = item.get('logprob', item.get('prob', 0))
result[token] = logprob
return result
def convert_logprobs_to_tiktoken(model, logprobs):
# more problems than it's worth.
# try:
@@ -72,6 +90,7 @@ def process_parameters(body, is_legacy=False):
elif isinstance(body['stop'], list):
generate_params['custom_stopping_strings'] = body['stop']
# For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
if shared.args.loader != 'llama.cpp':
from transformers import LogitsProcessorList
@@ -85,13 +104,10 @@ def process_parameters(body, is_legacy=False):
if logit_bias: # {str: float, ...}
logits_processor = [LogitsBiasProcessor(logit_bias)]
logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
logprobs = body.get('logprobs', None)
if logprobs is not None and logprobs > 0:
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([generate_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
@@ -456,6 +472,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
logger.info(f"Found {len(raw_images)} image(s) in request.")
generate_params['raw_images'] = raw_images
n_completions = body.get('n', 1) or 1
if not stream:
prompt_arg = body[prompt_str]
@@ -469,6 +487,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
resp_list_data = []
total_completion_token_count = 0
total_prompt_token_count = 0
choice_index = 0
for idx, prompt in enumerate(prompt_arg, start=0):
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
@@ -483,31 +502,46 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
prompt = decode(prompt)[0]
prefix = prompt if echo else ''
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
original_seed = generate_params.get('seed', -1)
for _n in range(n_completions):
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
if original_seed >= 0:
generate_params['seed'] = original_seed + _n
resp_list_data.extend([respi])
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
if logprob_proc:
completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
elif shared.args.loader == 'llama.cpp':
llama_logprobs = get_logprobs_from_llama_cpp()
completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
else:
completion_logprobs = None
respi = {
"index": choice_index,
"finish_reason": stop_reason,
"text": prefix + answer + suffix,
"logprobs": completion_logprobs,
}
resp_list_data.append(respi)
choice_index += 1
resp = {
"id": cmpl_id,
@@ -540,6 +574,14 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
def text_streaming_chunk(content):
# begin streaming
if logprob_proc:
chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
elif shared.args.loader == 'llama.cpp':
llama_logprobs = get_logprobs_from_llama_cpp()
chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
else:
chunk_logprobs = None
chunk = {
"id": cmpl_id,
"object": object_type,
@@ -549,7 +591,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
"logprobs": chunk_logprobs,
}],
}
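
For orientation, here is a rough sketch (not from the commit) of the response shape the non-streaming loop above assembles for n=2, with invented values; the logprobs field carries {token: logprob} alternatives in a top_logprobs list, or None when neither the LogprobProcessor nor llama.cpp provided any.

# Invented example of an assembled response; only the structure mirrors the diff.
example_response = {
    "id": "cmpl-0000000000",  # placeholder id
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "text": " Paris.",
            "logprobs": {"top_logprobs": [{" Paris": -0.11, " Lyon": -3.42}]},
        },
        {
            "index": 1,  # second completion, generated with seed + 1
            "finish_reason": "stop",
            "text": " Paris, of course.",
            "logprobs": None,
        },
    ],
}

# A client can walk the alternatives per choice like this:
for choice in example_response["choices"]:
    logprobs = choice.get("logprobs") or {}
    for alternatives in logprobs.get("top_logprobs", []):
        best = max(alternatives, key=alternatives.get)
        print(choice["index"], best, alternatives[best])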

View file

@@ -119,6 +119,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
is_legacy = "/generate" in path
if request_data.stream:
if (request_data.n or 1) > 1:
return JSONResponse(
status_code=400,
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
)
stop_event = threading.Event()
async def generator():

View file

@@ -109,7 +109,7 @@ class CompletionRequestParams(BaseModel):
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 512
n: int | None = Field(default=1, description="Unused parameter.")
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
presence_penalty: float | None = shared.args.presence_penalty
stop: str | List[str] | None = None
stream: bool | None = False

View file

@@ -133,9 +133,20 @@ class LlamaServer:
payload["samplers"] = filtered_samplers
logit_bias = []
if state['custom_token_bans']:
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
logit_bias.extend([[int(token_id), False] for token_id in state['custom_token_bans'].split(',')])
if state.get('logit_bias'):
for token_id_str, bias in state['logit_bias'].items():
logit_bias.append([int(token_id_str), bias])
if logit_bias:
payload["logit_bias"] = logit_bias
n_probs = state.get('logprobs', 0)
if n_probs and n_probs > 0:
payload["n_probs"] = n_probs
return payload
@@ -215,6 +226,7 @@ class LlamaServer:
response.raise_for_status() # Raise an exception for HTTP errors
full_text = ""
self.last_completion_probabilities = []
# Process the streaming response
stop_event = state.get('stop_event')
@@ -240,6 +252,10 @@ class LlamaServer:
full_text += data['content']
yield full_text
# Capture logprobs if present
if 'completion_probabilities' in data:
self.last_completion_probabilities.extend(data['completion_probabilities'])
# Check if generation is complete
if data.get('stop', False):
break
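
To summarize the payload handling above, the following standalone sketch (not part of the commit) reproduces the conversion idea with an invented state dict: OpenAI-style logit_bias entries and custom token bans are merged into llama.cpp's [[token_id, bias-or-false], ...] list, logprobs maps to n_probs, and the streamed completion_probabilities entries are what last_completion_probabilities accumulates.

# Invented example state; the structure mirrors prepare_payload() in the diff.
state = {
    "custom_token_bans": "13,29871",            # tokens to ban outright
    "logit_bias": {"15043": 5.0, "450": -2.0},  # soft biases from the request
    "logprobs": 3,
}

payload = {}
logit_bias = [[int(t), False] for t in state["custom_token_bans"].split(",")]
logit_bias += [[int(t), bias] for t, bias in state["logit_bias"].items()]
if logit_bias:
    payload["logit_bias"] = logit_bias
if state.get("logprobs"):
    payload["n_probs"] = state["logprobs"]

# A streamed llama.cpp chunk can then carry per-token alternatives, roughly shaped like:
chunk = {
    "content": " Paris",
    "completion_probabilities": [
        {"token": " Paris", "logprob": -0.11,
         "top_logprobs": [{"token": " Paris", "logprob": -0.11},
                          {"token": " Lyon", "logprob": -3.42}]},
    ],
}
# get_logprobs_from_llama_cpp() flattens these entries into a {token: logprob} dict.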