Forward logit_bias, logprobs, and n to llama.cpp backend

- Forward logit_bias and logprobs natively to llama.cpp
- Support n>1 completions with seed increment for diversity
- Fix logprobs returning empty dict when not requested
This commit is contained in:
oobabooga 2026-03-10 10:41:15 -03:00
parent 6ec4ca8b10
commit 8aeaa76365
4 changed files with 94 additions and 30 deletions

View file

@ -133,9 +133,20 @@ class LlamaServer:
payload["samplers"] = filtered_samplers
logit_bias = []
if state['custom_token_bans']:
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
payload["logit_bias"] = to_ban
logit_bias.extend([[int(token_id), False] for token_id in state['custom_token_bans'].split(',')])
if state.get('logit_bias'):
for token_id_str, bias in state['logit_bias'].items():
logit_bias.append([int(token_id_str), bias])
if logit_bias:
payload["logit_bias"] = logit_bias
n_probs = state.get('logprobs', 0)
if n_probs and n_probs > 0:
payload["n_probs"] = n_probs
return payload
@ -215,6 +226,7 @@ class LlamaServer:
response.raise_for_status() # Raise an exception for HTTP errors
full_text = ""
self.last_completion_probabilities = []
# Process the streaming response
stop_event = state.get('stop_event')
@ -240,6 +252,10 @@ class LlamaServer:
full_text += data['content']
yield full_text
# Capture logprobs if present
if 'completion_probabilities' in data:
self.last_completion_probabilities.extend(data['completion_probabilities'])
# Check if generation is complete
if data.get('stop', False):
break