From 609c3ac8936092af1df48305a49af3f0a3da165e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 15 Jun 2025 08:03:27 -0700
Subject: [PATCH] Optimize the end of generation with llama.cpp

---
 modules/llama_cpp_server.py |  2 ++
 modules/text_generation.py  | 10 ++++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index a79e24e4..e64f1694 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -30,6 +30,7 @@ class LlamaServer:
         self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = ""
+        self.last_prompt_token_count = 0

         # Start the server
         self._start_server()
@@ -128,6 +129,7 @@ class LlamaServer:
         payload = self.prepare_payload(state)

         token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
+        self.last_prompt_token_count = len(token_ids)
         if state['auto_max_new_tokens']:
             max_new_tokens = state['truncation_length'] - len(token_ids)
         else:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 55b538b0..a75141f1 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -498,8 +498,14 @@ def generate_reply_custom(question, original_question, state, stopping_strings=N
         traceback.print_exc()
     finally:
         t1 = time.time()
-        original_tokens = len(encode(original_question)[0])
-        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+
+        if hasattr(shared.model, 'last_prompt_token_count'):
+            original_tokens = shared.model.last_prompt_token_count
+            new_tokens = len(encode(reply)[0]) if reply else 0
+        else:
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+
         logger.info(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {state["seed"]})')
         return
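
Why this helps: before this change, the stats line at the end of generate_reply_custom() re-encoded both the original prompt and the prompt plus the reply just to count tokens. With the llama.cpp backend, encode() is answered by the external llama-server process rather than an in-process tokenizer, so for long contexts that re-tokenization could add a noticeable pause after the reply had already finished. The patch has LlamaServer remember the prompt size it already computed before generation (last_prompt_token_count) and then only encodes the reply for the final report, keeping the old behaviour as a fallback for any shared.model that does not expose that attribute.

Below is a minimal, self-contained sketch of the same caching pattern, under stated assumptions: FakeServer, its whitespace "tokenizer", and report_stats() are made-up stand-ins, not the real LlamaServer or encode() from text-generation-webui.

    # Sketch of the prompt-token-count caching used by this patch (hypothetical names).
    import time


    class FakeServer:
        """Stand-in for LlamaServer: caches the prompt's token count at encode time."""

        def __init__(self):
            self.last_prompt_token_count = 0

        def encode(self, text):
            # Hypothetical whitespace "tokenizer"; the real backend asks the
            # llama.cpp server to tokenize.
            return text.split()

        def generate(self, prompt):
            token_ids = self.encode(prompt)
            self.last_prompt_token_count = len(token_ids)  # remembered once, up front
            return "a generated reply"


    def report_stats(model, original_question, reply, t0, t1):
        # Same fallback logic as the patch: reuse the cached prompt size when the
        # backend exposes it, otherwise re-encode prompt and prompt + reply.
        if hasattr(model, 'last_prompt_token_count'):
            original_tokens = model.last_prompt_token_count
            new_tokens = len(model.encode(reply)) if reply else 0
        else:
            original_tokens = len(model.encode(original_question))
            new_tokens = len(model.encode(original_question + reply)) - original_tokens

        print(f'{new_tokens} tokens in {t1 - t0:.2f} seconds, context {original_tokens}')


    if __name__ == '__main__':
        model = FakeServer()
        prompt = "token " * 50000  # a long context
        t0 = time.time()
        reply = model.generate(prompt)
        t1 = time.time()
        report_stats(model, prompt, reply, t0, t1)

The hasattr() check is what keeps the change backend-agnostic: only models that set last_prompt_token_count take the fast path, and everything else continues to compute the stats exactly as before.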