diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 8ba031c1..d70e69e6 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -37,6 +37,24 @@ def load_chat_template_file(filepath):
     return text
 
 
+def get_logprobs_from_llama_cpp():
+    """Read logprobs captured from the llama.cpp server response."""
+    if not hasattr(shared.model, 'last_completion_probabilities') or not shared.model.last_completion_probabilities:
+        return None
+
+    # Convert llama.cpp format to {token: logprob} dict
+    # llama.cpp returns: [{"token": "text", "logprob": -0.5, "top_logprobs": [{"token": "t", "logprob": -0.1}, ...]}, ...]
+    result = {}
+    for entry in shared.model.last_completion_probabilities:
+        top = entry.get('top_logprobs', entry.get('top_probs', []))
+        for item in top:
+            token = item.get('token', '')
+            logprob = item.get('logprob', item.get('prob', 0))
+            result[token] = logprob
+
+    return result
+
+
 def convert_logprobs_to_tiktoken(model, logprobs):
     # more problems than it's worth.
     # try:
@@ -72,6 +90,7 @@ def process_parameters(body, is_legacy=False):
         elif isinstance(body['stop'], list):
             generate_params['custom_stopping_strings'] = body['stop']
 
+    # For llama.cpp, logit_bias and logprobs are forwarded natively via prepare_payload()
    if shared.args.loader != 'llama.cpp':
         from transformers import LogitsProcessorList
 
@@ -85,13 +104,10 @@ def process_parameters(body, is_legacy=False):
         if logit_bias:  # {str: float, ...}
             logits_processor = [LogitsBiasProcessor(logit_bias)]
 
-        logprobs = None  # coming to chat eventually
-        if 'logprobs' in body:
-            logprobs = body.get('logprobs', 0)  # maybe cap at topk? don't clamp 0-5.
+        logprobs = body.get('logprobs', None)
+        if logprobs is not None and logprobs > 0:
             generate_params['logprob_proc'] = LogprobProcessor(logprobs)
             logits_processor.extend([generate_params['logprob_proc']])
-        else:
-            logprobs = None
 
         if logits_processor:  # requires logits_processor support
             generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
@@ -456,6 +472,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
             logger.info(f"Found {len(raw_images)} image(s) in request.")
             generate_params['raw_images'] = raw_images
 
+    n_completions = body.get('n', 1) or 1
+
     if not stream:
         prompt_arg = body[prompt_str]
 
@@ -469,6 +487,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
         resp_list_data = []
         total_completion_token_count = 0
         total_prompt_token_count = 0
+        choice_index = 0
 
         for idx, prompt in enumerate(prompt_arg, start=0):
             if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], int):
@@ -483,31 +502,46 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
                 prompt = decode(prompt)[0]
 
             prefix = prompt if echo else ''
-
-            # generate reply #######################################
-            debug_msg({'prompt': prompt, 'generate_params': generate_params})
-            generator = generate_reply(prompt, generate_params, is_chat=False)
-            answer = ''
-
-            for a in generator:
-                answer = a
-
             token_count = len(encode(prompt)[0])
             total_prompt_token_count += token_count
-            completion_token_count = len(encode(answer)[0])
-            total_completion_token_count += completion_token_count
-            stop_reason = "stop"
-            if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
-                stop_reason = "length"
-
-            respi = {
-                "index": idx,
-                "finish_reason": stop_reason,
-                "text": prefix + answer + suffix,
-                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
-            }
+            original_seed = generate_params.get('seed', -1)
+
+            for _n in range(n_completions):
+                # Increment seed for each completion to ensure diversity (matches llama.cpp native behavior)
+                if original_seed >= 0:
+                    generate_params['seed'] = original_seed + _n
 
-            resp_list_data.extend([respi])
+                # generate reply #######################################
+                debug_msg({'prompt': prompt, 'generate_params': generate_params})
+                generator = generate_reply(prompt, generate_params, is_chat=False)
+                answer = ''
+
+                for a in generator:
+                    answer = a
+
+                completion_token_count = len(encode(answer)[0])
+                total_completion_token_count += completion_token_count
+                stop_reason = "stop"
+                if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
+                    stop_reason = "length"
+
+                if logprob_proc:
+                    completion_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+                elif shared.args.loader == 'llama.cpp':
+                    llama_logprobs = get_logprobs_from_llama_cpp()
+                    completion_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+                else:
+                    completion_logprobs = None
+
+                respi = {
+                    "index": choice_index,
+                    "finish_reason": stop_reason,
+                    "text": prefix + answer + suffix,
+                    "logprobs": completion_logprobs,
+                }
+
+                resp_list_data.append(respi)
+                choice_index += 1
 
         resp = {
             "id": cmpl_id,
@@ -540,6 +574,14 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
 
         def text_streaming_chunk(content):
             # begin streaming
+            if logprob_proc:
+                chunk_logprobs = {'top_logprobs': [logprob_proc.token_alternatives]}
+            elif shared.args.loader == 'llama.cpp':
+                llama_logprobs = get_logprobs_from_llama_cpp()
+                chunk_logprobs = {'top_logprobs': [llama_logprobs]} if llama_logprobs else None
+            else:
+                chunk_logprobs = None
+
             chunk = {
                 "id": cmpl_id,
                 "object": object_type,
@@ -549,7 +591,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e
                     "index": 0,
                     "finish_reason": None,
                     "text": content,
-                    "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
+                    "logprobs": chunk_logprobs,
                 }],
             }
 
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 7a13638d..94c7650f 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -119,6 +119,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
     is_legacy = "/generate" in path
 
     if request_data.stream:
+        if (request_data.n or 1) > 1:
+            return JSONResponse(
+                status_code=400,
+                content={"error": {"message": "n > 1 is not supported with streaming.", "type": "invalid_request_error", "param": "n", "code": None}}
+            )
+
         stop_event = threading.Event()
 
         async def generator():
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index e48b7b60..2156074b 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -109,7 +109,7 @@ class CompletionRequestParams(BaseModel):
     logit_bias: dict | None = None
     logprobs: int | None = None
     max_tokens: int | None = 512
-    n: int | None = Field(default=1, description="Unused parameter.")
+    n: int | None = Field(default=1, description="Number of completions to generate. Only supported without streaming.")
     presence_penalty: float | None = shared.args.presence_penalty
     stop: str | List[str] | None = None
     stream: bool | None = False
diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 6f7cbd20..a3e431ac 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -133,9 +133,20 @@ class LlamaServer:
 
         payload["samplers"] = filtered_samplers
 
+        logit_bias = []
         if state['custom_token_bans']:
-            to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
-            payload["logit_bias"] = to_ban
+            logit_bias.extend([[int(token_id), False] for token_id in state['custom_token_bans'].split(',')])
+
+        if state.get('logit_bias'):
+            for token_id_str, bias in state['logit_bias'].items():
+                logit_bias.append([int(token_id_str), bias])
+
+        if logit_bias:
+            payload["logit_bias"] = logit_bias
+
+        n_probs = state.get('logprobs', 0)
+        if n_probs and n_probs > 0:
+            payload["n_probs"] = n_probs
 
         return payload
 
@@ -215,6 +226,7 @@ class LlamaServer:
         response.raise_for_status()  # Raise an exception for HTTP errors
 
         full_text = ""
+        self.last_completion_probabilities = []
 
         # Process the streaming response
         stop_event = state.get('stop_event')
@@ -240,6 +252,10 @@ class LlamaServer:
                         full_text += data['content']
                         yield full_text
 
+                    # Capture logprobs if present
+                    if 'completion_probabilities' in data:
+                        self.last_completion_probabilities.extend(data['completion_probabilities'])
+
                     # Check if generation is complete
                     if data.get('stop', False):
                         break
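
A quick way to exercise the changed request path end to end is a small client script. The following is a minimal sketch, not part of the diff: it assumes the OpenAI-compatible API is listening at its default `http://127.0.0.1:5000` address, and the prompt, seed, and token budget are illustrative placeholders.

```python
# Minimal smoke test for the new `n` and `logprobs` handling (sketch only).
# Assumes the OpenAI-compatible API is running at http://127.0.0.1:5000.
import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "n": 2,          # two choices; combining n > 1 with "stream": true now returns HTTP 400
        "logprobs": 5,   # forwarded to llama.cpp as n_probs when that loader is active
        "seed": 42,      # with n > 1, choice i is generated with seed 42 + i
    },
)
resp.raise_for_status()

# Each choice carries its own index, text, and (if requested) logprobs.
for choice in resp.json()["choices"]:
    print(choice["index"], repr(choice["text"]))
    if choice.get("logprobs"):
        print("  top_logprobs:", choice["logprobs"]["top_logprobs"])
```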