diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index c01f5d5b..f77d2e07 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -296,19 +296,26 @@ class LlamaServer: pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload) print() - for retry in range(5): - response = self.session.post(url, json=payload) - result = response.json() + def _try_fetch_logits(): + for retry in range(5): + response = self.session.post(url, json=payload) + result = response.json() - if "completion_probabilities" in result: - if use_samplers: - return result["completion_probabilities"][0]["top_probs"] - else: - return result["completion_probabilities"][0]["top_logprobs"] + if "completion_probabilities" in result: + if use_samplers: + return result["completion_probabilities"][0]["top_probs"] + else: + return result["completion_probabilities"][0]["top_logprobs"] - time.sleep(0.05) - else: - raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") + time.sleep(0.05) + else: + raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}") + + result = _try_fetch_logits() + for entry in result: + if not entry.get('token'): + entry['token'] = self.decode([entry['id']]) + return result def get_prompt_logprob_entries(self, token_ids, n_probs=5, prompt=""): """Get logprob entries for prompt tokens via a single n_predict=0 request. diff --git a/modules/logits.py b/modules/logits.py index 473f5890..09849721 100644 --- a/modules/logits.py +++ b/modules/logits.py @@ -36,15 +36,22 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur return error_message, previous # llama.cpp case + def _escaped(token): + chars = [] + for a in token: + # C0 and DEL and C1 + if ord(a) <= 0x1F or 0x7F <= ord(a) <= 0x9F: + chars.append(repr(a)[1:-1]) + else: + chars.append(a) + return ''.join(chars) if shared.model.__class__.__name__ == 'LlamaServer': logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers) if return_dict: output = {} for entry in logprobs: - token = repr(entry['token']) - if len(token) > 2 and token.startswith("'") and token.endswith("'"): - token = token[1:-1] + token = _escaped(entry['token']) prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) output[token] = prob @@ -52,12 +59,11 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur else: output = '' for entry in logprobs: - token = repr(entry['token']) - if len(token) > 2 and token.startswith("'") and token.endswith("'"): - token = token[1:-1] + token = _escaped(entry['token']) + token_id = entry['id'] prob = entry['prob'] if use_samplers else np.exp(entry['logprob']) - output += f"{prob:.5f} - {token}\n" + output += f"{prob:.5f} - [{token}] ({token_id})\n" return output, previous # All other model types