mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-18 03:14:39 +01:00
Forward logit_bias, logprobs, and n to llama.cpp backend
- Forward logit_bias and logprobs natively to llama.cpp
- Support n>1 completions with seed increment for diversity
- Fix logprobs returning empty dict when not requested
This commit is contained in:
parent
6ec4ca8b10
commit
8aeaa76365
|
|
@ -37,6 +37,24 @@ def load_chat_template_file(filepath):
|
|||
return text
|
||||
|
||||
|
||||
def get_logprobs_from_llama_cpp():
|
||||
"""Read logprobs captured from the llama.cpp server response."""
|
||||
if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
|
||||
return None
|
||||
|
||||
# Convert llama.cpp format to {token: logprob} dict
|
||||
# llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
|
||||
result = {}
|
||||
for entry in shared.model.last_completion_probabilities:
|
||||
top = entry.get('top_logprobs', entry.get('top_probs', []))
|
||||
for item in top:
|
||||
token = item.get('token', '')
|
||||
logprob = item.get('logprob', item.get('prob', 0))
|
||||
result[token] = logprob
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
# more problems than it's worth.
|
||||
# try:
|
||||
|
|
@ -72,6 +90,7 @@ def process_parameters(body, is_legacy=False):
|
|||
elif isinstance(body['stop'], list):
|
||||
generate_params['custom_stopping_strings'] = body['stop']
|
||||
|
||||
# For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
|
||||
if shared.args.loader != 'llama.cpp':
|
||||
from transformers import LogitsProcessorList
|
||||
|
||||
|
|
@ -85,13 +104,10 @@ def process_parameters(body, is_legacy=False):
|
|||
if logit_bias: # {str: float, ...}
|
||||
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||
|
||||
logprobs = None # coming to chat eventually
|
||||
if 'logprobs' in body:
|
||||
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
logprobs = body.get('logprobs', None)
|
||||
if logprobs is not None and logprobs > 0:
|
||||
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([generate_params['logprob_proc']])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if logits_processor: # requires logits_processor support
|
||||
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
|
|
@ -456,6 +472,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
logger.info(f"Found {len(raw_images)} image(s) in request.")
|
||||
generate_params['raw_images'] = raw_images
|
||||
|
||||
n_completions = body.get('n', 1) or 1
|
||||
|
||||
if not stream:
|
||||
prompt_arg = body[prompt_str]
|
||||
|
||||
|
|
@ -469,6 +487,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
choice_index = 0
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
|
||||
|
|
@ -483,31 +502,46 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
prompt = decode(prompt)[0]
|
||||
|
||||
prefix = prompt if echo else ''
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": prefix + answer + suffix,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}
|
||||
original_seed = generate_params.get('seed', -1)
|
||||
for _n in range(n_completions):
|
||||
# Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
|
||||
if original_seed >= 0:
|
||||
generate_params['seed'] = original_seed + _n
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
if logprob_proc:
|
||||
completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
|
||||
elif shared.args.loader == 'llama.cpp':
|
||||
llama_logprobs = get_logprobs_from_llama_cpp()
|
||||
completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
|
||||
else:
|
||||
completion_logprobs = None
|
||||
|
||||
respi = {
|
||||
"index": choice_index,
|
||||
"finish_reason": stop_reason,
|
||||
"text": prefix + answer + suffix,
|
||||
"logprobs": completion_logprobs,
|
||||
}
|
||||
|
||||
resp_list_data.append(respi)
|
||||
choice_index += 1
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
|
|
@ -540,6 +574,14 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
if logprob_proc:
|
||||
chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
|
||||
elif shared.args.loader == 'llama.cpp':
|
||||
llama_logprobs = get_logprobs_from_llama_cpp()
|
||||
chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
|
||||
else:
|
||||
chunk_logprobs = None
|
||||
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
|
|
@ -549,7 +591,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
|
|||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
"logprobs": chunk_logprobs,
|
||||
}],
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
|||
is_legacy = "/generate" in path
|
||||
|
||||
if request_data.stream:
|
||||
if (request_data.n or 1) > 1:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
|
||||
)
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
async def generator():
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class CompletionRequestParams(BaseModel):
|
|||
logit_bias: dict | None = None
|
||||
logprobs: int | None = None
|
||||
max_tokens: int | None = 512
|
||||
n: int | None = Field(default=1, description="Unused parameter.")
|
||||
n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
|
||||
presence_penalty: float | None = shared.args.presence_penalty
|
||||
stop: str | List[str] | None = None
|
||||
stream: bool | None = False
|
||||
|
|
|
|||
|
|
@ -133,9 +133,20 @@ class LlamaServer:
|
|||
|
||||
payload["samplers"] = filtered_samplers
|
||||
|
||||
logit_bias = []
|
||||
if state['custom_token_bans']:
|
||||
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
|
||||
payload["logit_bias"] = to_ban
|
||||
logit_bias.extend([[int(token_id), False] for token_id in state['custom_token_bans'].split(',')])
|
||||
|
||||
if state.get('logit_bias'):
|
||||
for token_id_str, bias in state['logit_bias'].items():
|
||||
logit_bias.append([int(token_id_str), bias])
|
||||
|
||||
if logit_bias:
|
||||
payload["logit_bias"] = logit_bias
|
||||
|
||||
n_probs = state.get('logprobs', 0)
|
||||
if n_probs and n_probs > 0:
|
||||
payload["n_probs"] = n_probs
|
||||
|
||||
return payload
|
||||
|
||||
|
|
@ -215,6 +226,7 @@ class LlamaServer:
|
|||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
|
||||
full_text = ""
|
||||
self.last_completion_probabilities = []
|
||||
|
||||
# Process the streaming response
|
||||
stop_event = state.get('stop_event')
|
||||
|
|
@ -240,6 +252,10 @@ class LlamaServer:
|
|||
full_text += data['content']
|
||||
yield full_text
|
||||
|
||||
# Capture logprobs if present
|
||||
if 'completion_probabilities' in data:
|
||||
self.last_completion_probabilities.extend(data['completion_probabilities'])
|
||||
|
||||
# Check if generation is complete
|
||||
if data.get('stop', False):
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in a new issue