diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 5071c40c..faf6e20e 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -27,6 +27,8 @@ class LlamaServer:
         self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = ""
+        self.last_input_length = 0
+        self.last_output_length = 0
 
         # Start the server
         self._start_server()
@@ -140,6 +142,9 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
+        self.last_input_length = len(token_ids)
+        self.last_output_length = 0
+
         # Make a direct request with streaming enabled using a context manager
         with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
@@ -151,30 +156,32 @@ class LlamaServer:
                 if shared.stop_everything:
                     break
 
-                if line:
-                    try:
-                        # Check if the line starts with "data: " and remove it
-                        if line.startswith('data: '):
-                            line = line[6:]  # Remove the "data: " prefix
+                if not line:
+                    continue
 
-                        # Parse the JSON data
-                        data = json.loads(line)
+                try:
+                    # Check if the line starts with "data: " and remove it
+                    if line.startswith('data: '):
+                        line = line[6:]  # Remove the "data: " prefix
 
-                        # Extract the token content
-                        if 'content' in data:
-                            token_text = data['content']
-                            full_text += token_text
-                            yield full_text
+                    # Parse the JSON data
+                    data = json.loads(line)
 
-                        # Check if generation is complete
-                        if data.get('stop', False):
-                            break
+                    # Extract the token content
+                    if data.get('content', ''):
+                        full_text += data['content']
+                        self.last_output_length += 1
+                        yield full_text
 
-                    except json.JSONDecodeError as e:
-                        # Log the error and the problematic line
-                        print(f"JSON decode error: {e}")
-                        print(f"Problematic line: {line}")
-                        continue
+                    # Check if generation is complete
+                    if data.get('stop', False):
+                        break
+
+                except json.JSONDecodeError as e:
+                    # Log the error and the problematic line
+                    print(f"JSON decode error: {e}")
+                    print(f"Problematic line: {line}")
+                    continue
 
     def generate(self, prompt, state):
         output = ""
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 16aba3cb..675eb379 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -481,8 +481,13 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         traceback.print_exc()
     finally:
         t1 = time.time()
-        original_tokens = len(encode(original_question)[0])
-        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+        if shared.args.loader == 'llama.cpp':
+            original_tokens = shared.model.last_input_length
+            new_tokens = shared.model.last_output_length
+        else:
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 