def _first_token_display_str(token_id, prompt, tokenizer):
    """Return the display string for the first prompt token.

    Returns an empty string for BOS, or for a token whose decoded text is not
    a prefix of ``prompt``, so that it does not shift ``text_offset`` for the
    tokens that follow.

    Args:
        token_id: ID of the first prompt token (anything ``int()`` accepts).
        prompt: The prompt text the token was encoded from.
        tokenizer: Object exposing ``decode`` and, optionally, ``bos_token_id``.

    Returns:
        The decoded token text, or ``""``.
    """
    token_id = int(token_id)
    bos_id = getattr(tokenizer, 'bos_token_id', None)
    if bos_id is not None and token_id == bos_id:
        return ""

    # torch is imported at function scope; the BOS path above returns without it.
    import torch
    tok = tokenizer.decode(torch.tensor([token_id]))
    return tok if prompt.startswith(tok) else ""


def _compute_prompt_logprob_entries(prompt, logprobs_count, input_ids=None):
    """Compute logprob entries for prompt tokens via a forward pass.

    Returns a list of logprob entries in the standard format.
    The first token gets a null entry (no conditioning context).

    Supported for HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
    via a single forward pass, and for llama.cpp via the server's
    prompt_logprobs parameter. Returns [] for unsupported loaders.

    Args:
        prompt: Prompt text; used for the first token's display string and
            forwarded to the llama.cpp backend.
        logprobs_count: Number of top alternatives requested per position.
            Values below 0 are treated as 0 for the forward-pass loaders;
            llama.cpp always receives at least 1.
        input_ids: Optional pre-encoded prompt (indexable so that
            ``input_ids[0]`` is the token sequence). Encoded from ``prompt``
            when omitted.

    Returns:
        A list whose first element is ``{"token": str, "null_logprob": True}``
        and whose remaining elements are ``{"top_logprobs": [...]}`` with the
        actual prompt token always at index 0. Empty list for an empty prompt
        or an unsupported loader.
    """
    if input_ids is None:
        input_ids = encode(prompt)  # (1, seq_len) tensor or array

    token_ids = input_ids[0]
    n_tokens = len(token_ids)
    if n_tokens == 0:
        return []

    loader = shared.args.loader
    model = shared.model

    if loader == 'llama.cpp':
        return model.get_prompt_logprob_entries(token_ids, max(logprobs_count, 1), prompt=prompt)

    first_entry = {
        "token": _first_token_display_str(token_ids[0], prompt, shared.tokenizer),
        "null_logprob": True,
    }
    if n_tokens <= 1:
        return [first_entry]

    import torch

    use_native_exllamav3 = loader == 'ExLlamav3' and hasattr(model, 'model') and hasattr(model, 'cache')
    if not use_native_exllamav3 and not hasattr(model, 'forward'):
        return []

    # Both supported paths need the prompt as a tensor; convert once.
    if isinstance(input_ids, torch.Tensor):
        input_ids_tensor = input_ids
    else:
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)

    with torch.no_grad():
        if use_native_exllamav3:
            # Native ExLlamav3: call the underlying Model.forward() directly
            logits = model.model.forward(
                input_ids=input_ids_tensor,
                params={
                    "attn_mode": "flash_attn",
                    "cache": model.cache,
                    "past_len": 0,
                    "batch_shape": (1, model.max_tokens),
                }
            ).float().cpu()
        else:
            # HF-compatible loaders (Transformers, ExLlamav3_HF, etc.)
            if hasattr(model, 'device'):
                input_ids_tensor = input_ids_tensor.to(model.device)

            # Pass labels to ensure logits are returned for ALL positions,
            # not just the last token (some HF wrappers like ExLlamav3_HF
            # only compute the last-token logits when labels are absent).
            outputs = model(input_ids=input_ids_tensor, labels=input_ids_tensor)
            logits = outputs.logits.float().cpu()

    # Run logsumexp and topk once across all positions instead of per position.
    prompt_logits = logits[0, :n_tokens - 1]  # positions 0..n-2 predict tokens 1..n-1
    k = max(0, min(logprobs_count, prompt_logits.shape[-1]))  # topk rejects negative k
    all_top_values, all_top_indices = torch.topk(prompt_logits, k=k, dim=-1)
    all_lse = torch.logsumexp(prompt_logits, dim=-1)

    # Convert to plain Python lists once rather than calling .item() per element.
    top_ids_rows = all_top_indices.tolist()
    top_lp_rows = (all_top_values - all_lse.unsqueeze(-1)).tolist()

    # Decode every distinct token ID only once; IDs repeat heavily across
    # positions, so this avoids O(N*k) decode calls.
    prompt_token_ids = [int(tid) for tid in token_ids[1:]]
    unique_ids = set(prompt_token_ids)
    unique_ids.update(tid for row in top_ids_rows for tid in row)
    decoded_strs = {tid: shared.tokenizer.decode(torch.tensor([tid])) for tid in unique_ids}

    entries = [first_entry]
    for idx, token_id in enumerate(prompt_token_ids):
        top_ids = top_ids_rows[idx]
        top_lps = top_lp_rows[idx]

        # Build the top list with the actual prompt token guaranteed at front
        if token_id in top_ids:
            actual_lp = top_lps[top_ids.index(token_id)]
            alternatives = [
                {"token": decoded_strs[tid], "logprob": lp}
                for tid, lp in zip(top_ids, top_lps) if tid != token_id
            ]
        else:
            actual_lp = (prompt_logits[idx, token_id] - all_lse[idx]).item()
            # Drop the lowest-ranked alternative to make room for the actual token.
            alternatives = [
                {"token": decoded_strs[tid], "logprob": lp}
                for tid, lp in zip(top_ids[:-1], top_lps[:-1])
            ]

        entries.append({
            "top_logprobs": [{"token": decoded_strs[token_id], "logprob": actual_lp}] + alternatives
        })

    return entries


def _extract_sampled_token(entry, top):
    """Get the actually sampled token and its logprob from a logprob entry.

    Uses the entry-level token/logprob when available (the actually sampled
    token), falling back to top[0] (highest-probability alternative) which
    may differ with non-greedy sampling.

    Args:
        entry: Logprob entry dict; may carry ``token`` and ``logprob``/``prob``.
        top: List of alternative dicts for the same position.

    Returns:
        A ``(token_str, logprob)`` tuple. ``('', 0)`` when the entry has no
        ``token`` key and ``top`` is empty.
    """
    if 'token' in entry:
        return entry['token'], entry.get('logprob', entry.get('prob', 0))

    # Guard against an empty alternatives list instead of raising IndexError.
    if not top:
        return '', 0

    best = top[0]
    return best.get('token', ''), best.get('logprob', best.get('prob', 0))
is_legacy: bool = False, stream=False, stop_e # common params generate_params = process_parameters(body, is_legacy=is_legacy) max_tokens = generate_params['max_new_tokens'] + if max_tokens is None: + generate_params['max_new_tokens'] = 512 + generate_params['auto_max_new_tokens'] = True + max_tokens = 512 + elif max_tokens < 0: + raise InvalidRequestError(message="max_tokens must be greater than or equal to 0.", param="max_tokens") + elif max_tokens == 0 and body.get('logprobs') is None: + raise InvalidRequestError(message="max_tokens is 0 but no logprobs parameter was specified.", param="max_tokens") + generate_params['stream'] = stream if stop_event is not None: generate_params['stop_event'] = stop_event @@ -700,9 +856,17 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e prompt = decode(prompt)[0] prefix = prompt if echo else '' - token_count = len(encode(prompt)[0]) + prompt_input_ids = encode(prompt) + token_count = len(prompt_input_ids[0]) total_prompt_token_count += token_count + # Compute prompt logprobs once per prompt (shared across n_completions) + logprobs_val = body.get('logprobs', None) + if echo and logprobs_val is not None and logprobs_val >= 0: + prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids) + else: + prompt_entries = None + original_seed = generate_params.get('seed', -1) for _n in range(n_completions): # Increment seed for each completion to ensure diversity (matches llama.cpp native behavior) @@ -713,29 +877,41 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e logprob_proc.token_alternatives_history.clear() # generate reply ####################################### - debug_msg({'prompt': prompt, 'generate_params': generate_params}) - generator = generate_reply(prompt, generate_params, is_chat=False) - answer = '' - - for a in generator: - answer = a - - completion_token_count = len(encode(answer)[0]) - total_completion_token_count 
+= completion_token_count - stop_reason = "stop" - if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: - stop_reason = "length" - - if logprob_proc: - all_entries = [] - for alt in logprob_proc.token_alternatives_history: - all_entries.extend(_dict_to_logprob_entries(alt)) - completion_logprobs = format_completion_logprobs(all_entries) - elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): - raw = getattr(shared.model, 'last_completion_probabilities', None) - completion_logprobs = format_completion_logprobs(raw) + if max_tokens == 0: + answer = '' + completion_token_count = 0 + stop_reason = "stop" else: - completion_logprobs = None + debug_msg({'prompt': prompt, 'generate_params': generate_params}) + generator = generate_reply(prompt, generate_params, is_chat=False) + answer = '' + + for a in generator: + answer = a + + completion_token_count = len(encode(answer)[0]) + stop_reason = "stop" + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" + + total_completion_token_count += completion_token_count + + if max_tokens == 0: + all_entries = [] + else: + if logprob_proc: + all_entries = [] + for alt in logprob_proc.token_alternatives_history: + all_entries.extend(_dict_to_logprob_entries(alt)) + elif shared.args.loader in ('llama.cpp', 'ExLlamav3'): + all_entries = getattr(shared.model, 'last_completion_probabilities', None) or [] + else: + all_entries = [] + + if prompt_entries: + all_entries = prompt_entries + all_entries + + completion_logprobs = format_completion_logprobs(all_entries) if all_entries else None respi = { "index": choice_index, @@ -775,7 +951,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str) prefix = prompt if echo else '' - token_count = 
len(encode(prompt)[0]) + prompt_input_ids = encode(prompt) + token_count = len(prompt_input_ids[0]) # Check if usage should be included in streaming chunks per OpenAI spec stream_options = body.get('stream_options') @@ -808,37 +985,57 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False, stop_e return chunk + logprobs_val = body.get('logprobs', None) + if echo and logprobs_val is not None and logprobs_val >= 0: + prompt_entries = _compute_prompt_logprob_entries(prompt, logprobs_val, input_ids=prompt_input_ids) + prompt_logprobs_formatted = format_completion_logprobs(prompt_entries) if prompt_entries else None + else: + prompt_logprobs_formatted = None + + # Clear stale logprobs from any previous request before building the + # first chunk, so text_streaming_chunk doesn't pick up old data. + if hasattr(shared.model, 'last_completion_probabilities'): + shared.model.last_completion_probabilities = [] + cmpl_logprobs_offset[0] = 0 + chunk = text_streaming_chunk(prefix) + if prompt_logprobs_formatted is not None: + chunk[resp_list][0]["logprobs"] = prompt_logprobs_formatted if include_usage: chunk['usage'] = None yield chunk # generate reply ####################################### - debug_msg({'prompt': prompt, 'generate_params': generate_params}) - generator = generate_reply(prompt, generate_params, is_chat=False) - answer = '' - seen_content = '' - completion_token_count = 0 + if max_tokens == 0: + answer = '' + completion_token_count = 0 + stop_reason = "stop" + else: + debug_msg({'prompt': prompt, 'generate_params': generate_params}) + generator = generate_reply(prompt, generate_params, is_chat=False) + answer = '' + seen_content = '' + completion_token_count = 0 - for a in generator: - answer = a + for a in generator: + answer = a - len_seen = len(seen_content) - new_content = answer[len_seen:] + len_seen = len(seen_content) + new_content = answer[len_seen:] - if not new_content or chr(0xfffd) in new_content: # partial unicode character, 
don't send it yet. - continue + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + continue - seen_content = answer - chunk = text_streaming_chunk(new_content) - if include_usage: - chunk['usage'] = None - yield chunk + seen_content = answer + chunk = text_streaming_chunk(new_content) + if include_usage: + chunk['usage'] = None + yield chunk - completion_token_count = len(encode(answer)[0]) - stop_reason = "stop" - if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: - stop_reason = "length" + completion_token_count = len(encode(answer)[0]) + stop_reason = "stop" + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" chunk = text_streaming_chunk(suffix) chunk[resp_list][0]["finish_reason"] = stop_reason diff --git a/modules/exllamav3.py b/modules/exllamav3.py index f873503a..3782a693 100644 --- a/modules/exllamav3.py +++ b/modules/exllamav3.py @@ -489,15 +489,35 @@ class Exllamav3Model: return id_to_piece = self.tokenizer.get_id_to_piece_list(True) + sampled_ids = result.get("token_ids") # (batch, seq_len) - actually sampled tokens + sampled_probs = result.get("token_probs") # (batch, seq_len) - their probabilities + + def _piece(tid): + s = id_to_piece[tid] if tid < len(id_to_piece) else f"<{tid}>" + return s.replace('\u2581', ' ') + + def _logprob(prob): + return math.log(prob) if prob > 0 else float("-inf") + # top_k_tokens shape: (batch, seq_len, k), top_k_probs same for seq_idx in range(top_k_tokens.shape[1]): entry = {"top_logprobs": []} for k_idx in range(top_k_tokens.shape[2]): token_id = top_k_tokens[0, seq_idx, k_idx].item() prob = top_k_probs[0, seq_idx, k_idx].item() - token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>" - logprob = math.log(prob) if prob > 0 else float("-inf") - 
def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""):
    """Get logprob entries for prompt tokens via a single n_predict=0 request.

    Requires llama.cpp server with prompt_logprobs support.
    Returns entries in the standard format for format_completion_logprobs().

    Args:
        token_ids: Prompt token IDs (a list, or any object with ``tolist()``).
        n_probs: Number of alternatives requested per position.
        prompt: Prompt text, used only to decide whether the first token's
            decoded text should be shown.

    Returns:
        ``[]`` when ``token_ids`` is empty or the response carries no
        ``prompt_probabilities``; otherwise a null first-token entry followed
        by the server's entries, unchanged.
    """
    token_ids_list = token_ids.tolist() if hasattr(token_ids, 'tolist') else list(token_ids)
    if not token_ids_list:
        # Nothing to score: skip the HTTP round-trip and the [0] lookup below.
        return []

    url = f"http://127.0.0.1:{self.port}/completion"
    payload = {
        "prompt": token_ids_list,
        "n_predict": 0,
        "n_probs": n_probs,
        "prompt_logprobs": True,
        "stream": False,
        "cache_prompt": False,
    }

    response = self.session.post(url, json=payload)
    result = response.json()

    prompt_probs = result.get("prompt_probabilities", [])
    if not prompt_probs:
        return []

    # Null first token (no conditioning context); use empty string for BOS
    # or tokens that don't appear at the start of the prompt text.
    first_token_str = self.decode([token_ids_list[0]])
    is_bos = bool(self.bos_token) and first_token_str == self.bos_token
    if is_bos or not prompt.startswith(first_token_str):
        first_token_str = ""

    return [{"token": first_token_str, "null_logprob": True}, *prompt_probs]