From 5fdebc554b7ca46afb9695babf89397635e9f91d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 04:59:24 -0700
Subject: [PATCH 1/7] llama.cpp: close the connection immediately on 'Stop'

---
 modules/llama_cpp_server.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 5071c40c..3025aa7d 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -141,16 +141,24 @@ class LlamaServer:
             print()
 
         # Make a direct request with streaming enabled using a context manager
-        with self.session.post(url, json=payload, stream=True) as response:
+        with self.session.post(url, json=payload, stream=True, timeout=(5, 0.1)) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
 
             full_text = ""
+            iterator = response.iter_lines(decode_unicode=True)
 
-            # Process the streaming response
-            for line in response.iter_lines(decode_unicode=True):
+            while True:
                 if shared.stop_everything:
                     break
 
+                try:
+                    line = next(iterator)
+                except requests.exceptions.Timeout:
+                    # Check stop flag again on timeout
+                    continue
+                except StopIteration:
+                    break
+
                 if line:
                     try:
                         # Check if the line starts with "data: " and remove it

From ed42154c78d7d7a63c092684deaa47d22750d796 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 05:32:36 -0700
Subject: [PATCH 2/7] Revert "llama.cpp: close the connection immediately on
 'Stop'"

This reverts commit 5fdebc554b7ca46afb9695babf89397635e9f91d.
---
 modules/llama_cpp_server.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 3025aa7d..5071c40c 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -141,24 +141,16 @@ class LlamaServer:
             print()
 
         # Make a direct request with streaming enabled using a context manager
-        with self.session.post(url, json=payload, stream=True, timeout=(5, 0.1)) as response:
+        with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
 
             full_text = ""
-            iterator = response.iter_lines(decode_unicode=True)
 
-            while True:
+            # Process the streaming response
+            for line in response.iter_lines(decode_unicode=True):
                 if shared.stop_everything:
                     break
 
-                try:
-                    line = next(iterator)
-                except requests.exceptions.Timeout:
-                    # Check stop flag again on timeout
-                    continue
-                except StopIteration:
-                    break
-
                 if line:
                     try:
                         # Check if the line starts with "data: " and remove it
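Note on patches 1 and 2: passing timeout=(connect, read) to a streaming requests call bounds how long any single read can block, so the loop gets regular chances to re-check a stop flag instead of waiting indefinitely for the next token. The sketch below is only an illustration of that general technique, not code from the webui; the stream_lines function and STOP_REQUESTED flag are made-up names. One hedged caveat, possibly related to the quick revert (the patches do not say): in several requests/urllib3 versions, a read timeout hit while iterating a streamed body surfaces as requests.exceptions.ConnectionError wrapping a ReadTimeoutError rather than requests.exceptions.Timeout.

    import requests

    STOP_REQUESTED = False  # in a real UI this would be flipped from another thread

    def stream_lines(url, payload):
        # timeout=(connect, read): 5 s to establish the connection, then at most
        # 0.1 s blocked on any single read before a Timeout is raised
        with requests.post(url, json=payload, stream=True, timeout=(5, 0.1)) as response:
            response.raise_for_status()
            lines = response.iter_lines(decode_unicode=True)
            while not STOP_REQUESTED:
                try:
                    line = next(lines)
                except requests.exceptions.Timeout:
                    continue  # nothing arrived within the read timeout; re-check the flag
                except StopIteration:
                    break  # server closed the stream
                if line:
                    yield line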
From ba976d1390b7151633d693b270604517fd14e701 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 16:35:01 -0700
Subject: [PATCH 3/7] llama.cpp: avoid two 'encode' calls

---
 modules/llama_cpp_server.py | 47 +++++++++++++++++++++----------------
 modules/text_generation.py  |  9 +++++--
 2 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 5071c40c..faf6e20e 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -27,6 +27,8 @@ class LlamaServer:
         self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = ""
+        self.last_input_length = 0
+        self.last_output_length = 0
 
         # Start the server
         self._start_server()
@@ -140,6 +142,9 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
+        self.last_input_length = len(token_ids)
+        self.last_output_length = 0
+
         # Make a direct request with streaming enabled using a context manager
         with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
@@ -151,30 +156,32 @@ class LlamaServer:
                 if shared.stop_everything:
                     break
 
-                if line:
-                    try:
-                        # Check if the line starts with "data: " and remove it
-                        if line.startswith('data: '):
-                            line = line[6:]  # Remove the "data: " prefix
+                if not line:
+                    continue
 
-                        # Parse the JSON data
-                        data = json.loads(line)
+                try:
+                    # Check if the line starts with "data: " and remove it
+                    if line.startswith('data: '):
+                        line = line[6:]  # Remove the "data: " prefix
 
-                        # Extract the token content
-                        if 'content' in data:
-                            token_text = data['content']
-                            full_text += token_text
-                            yield full_text
+                    # Parse the JSON data
+                    data = json.loads(line)
 
-                        # Check if generation is complete
-                        if data.get('stop', False):
-                            break
+                    # Extract the token content
+                    if data.get('content', ''):
+                        full_text += data['content']
+                        self.last_output_length += 1
+                        yield full_text
 
-                    except json.JSONDecodeError as e:
-                        # Log the error and the problematic line
-                        print(f"JSON decode error: {e}")
-                        print(f"Problematic line: {line}")
-                        continue
+                    # Check if generation is complete
+                    if data.get('stop', False):
+                        break
+
+                except json.JSONDecodeError as e:
+                    # Log the error and the problematic line
+                    print(f"JSON decode error: {e}")
+                    print(f"Problematic line: {line}")
+                    continue
 
     def generate(self, prompt, state):
         output = ""
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 16aba3cb..675eb379 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -481,8 +481,13 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         traceback.print_exc()
     finally:
         t1 = time.time()
-        original_tokens = len(encode(original_question)[0])
-        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+        if shared.args.loader == 'llama.cpp':
+            original_tokens = shared.model.last_input_length
+            new_tokens = shared.model.last_output_length
+        else:
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 

From 9c9df2063f61d19fa1755b2906ba6804f88c4e68 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 16:38:02 -0700
Subject: [PATCH 4/7] llama.cpp: fix unicode decoding (closes #6856)

---
 modules/llama_cpp_server.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index faf6e20e..9c97e00b 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -152,7 +152,7 @@ class LlamaServer:
             full_text = ""
 
             # Process the streaming response
-            for line in response.iter_lines(decode_unicode=True):
+            for line in response.iter_lines():
                 if shared.stop_everything:
                     break
 
@@ -160,6 +160,8 @@ class LlamaServer:
                     continue
 
                 try:
+                    line = line.decode('utf-8')
+
                     # Check if the line starts with "data: " and remove it
                     if line.startswith('data: '):
                         line = line[6:]  # Remove the "data: " prefix
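Note on patch 4: response.iter_lines(decode_unicode=True) asks requests to decode the stream with whatever encoding it inferred for the response, and for an event stream served without an explicit charset that guess can be wrong (or no decoding happens at all), which garbles multi-byte UTF-8 characters. Iterating the raw bytes and calling .decode('utf-8') per line pins the encoding explicitly. A minimal, self-contained sketch of that pattern (the URL and payload are placeholders, not values from the patch):

    import json
    import requests

    def iter_completion_events(url, payload):
        with requests.post(url, json=payload, stream=True) as response:
            response.raise_for_status()
            for raw in response.iter_lines():      # raw bytes; no implicit decoding
                if not raw:
                    continue
                line = raw.decode('utf-8')         # llama.cpp streams UTF-8 JSON
                if line.startswith('data: '):      # strip the SSE prefix
                    line = line[len('data: '):]
                yield json.loads(line)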
From b9da5c7e3a34ef1a5ae2e372db7ff2c5299d523a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 17:36:04 -0700
Subject: [PATCH 5/7] Use 127.0.0.1 instead of localhost for faster llama.cpp
 on Windows

---
 modules/llama_cpp_server.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 9c97e00b..ebce987a 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -37,7 +37,7 @@ class LlamaServer:
         if self.bos_token and text.startswith(self.bos_token):
             add_bos_token = False
 
-        url = f"http://localhost:{self.port}/tokenize"
+        url = f"http://127.0.0.1:{self.port}/tokenize"
         payload = {
             "content": text,
             "add_special": add_bos_token,
@@ -48,7 +48,7 @@ class LlamaServer:
         return result.get("tokens", [])
 
     def decode(self, token_ids, **kwargs):
-        url = f"http://localhost:{self.port}/detokenize"
+        url = f"http://127.0.0.1:{self.port}/detokenize"
         payload = {
             "tokens": token_ids,
         }
@@ -121,7 +121,7 @@ class LlamaServer:
         return payload
 
     def generate_with_streaming(self, prompt, state):
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
         token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
@@ -194,7 +194,7 @@ class LlamaServer:
 
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
 
         payload = self.prepare_payload(state)
         payload.update({
@@ -225,7 +225,7 @@ class LlamaServer:
 
     def _get_vocabulary_size(self):
         """Get and store the model's maximum context length."""
-        url = f"http://localhost:{self.port}/v1/models"
+        url = f"http://127.0.0.1:{self.port}/v1/models"
         response = self.session.get(url).json()
 
         if "data" in response and len(response["data"]) > 0:
@@ -235,7 +235,7 @@ class LlamaServer:
 
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
-        url = f"http://localhost:{self.port}/props"
+        url = f"http://127.0.0.1:{self.port}/props"
         response = self.session.get(url).json()
         if "bos_token" in response:
             self.bos_token = response["bos_token"]
@@ -308,7 +308,7 @@ class LlamaServer:
         threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
 
         # Wait for server to be healthy
-        health_url = f"http://localhost:{self.port}/health"
+        health_url = f"http://127.0.0.1:{self.port}/health"
         start_time = time.time()
         timeout = 3600 * 8  # 8 hours
         while time.time() - start_time < timeout:
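Note on patch 5: on Windows, resolving "localhost" is often slow because it can be tried as ::1 first before falling back to IPv4, and that cost can be paid again whenever a new connection is opened; a literal 127.0.0.1 skips name resolution entirely. If you want to gauge the effect on a given machine, a rough timing loop like the one below works. It is illustrative only and assumes a llama.cpp server is already listening; PORT is a placeholder for whatever port the server was started with.

    import time
    import requests

    PORT = 8080  # placeholder; substitute the port your server is listening on

    def mean_request_time(host, n=20):
        session = requests.Session()
        url = f"http://{host}:{PORT}/health"
        start = time.perf_counter()
        for _ in range(n):
            session.get(url)
        return (time.perf_counter() - start) / n

    for host in ("localhost", "127.0.0.1"):
        print(f"{host}: {mean_request_time(host) * 1000:.1f} ms per request")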
From 5ab069786befb5473142a2c03de3a3870ba4a151 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 17:38:36 -0700
Subject: [PATCH 6/7] llama.cpp: add back the two encode calls (they are
 harmless now)

---
 modules/llama_cpp_server.py | 6 ------
 modules/text_generation.py  | 8 ++------
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index ebce987a..02a56b3c 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -27,8 +27,6 @@ class LlamaServer:
         self.session = requests.Session()
         self.vocabulary_size = None
         self.bos_token = ""
-        self.last_input_length = 0
-        self.last_output_length = 0
 
         # Start the server
         self._start_server()
@@ -142,9 +140,6 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
-        self.last_input_length = len(token_ids)
-        self.last_output_length = 0
-
         # Make a direct request with streaming enabled using a context manager
         with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
@@ -172,7 +167,6 @@ class LlamaServer:
                     # Extract the token content
                     if data.get('content', ''):
                         full_text += data['content']
-                        self.last_output_length += 1
                         yield full_text
 
                     # Check if generation is complete
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 675eb379..70f03443 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -481,12 +481,8 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         traceback.print_exc()
     finally:
         t1 = time.time()
-        if shared.args.loader == 'llama.cpp':
-            original_tokens = shared.model.last_input_length
-            new_tokens = shared.model.last_output_length
-        else:
-            original_tokens = len(encode(original_question)[0])
-            new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+        original_tokens = len(encode(original_question)[0])
+        new_tokens = len(encode(original_question + reply)[0]) - original_tokens
 
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return

From 6ba0164c70173cb03584d5954793cd1bec7c593b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 19 Apr 2025 17:45:21 -0700
Subject: [PATCH 7/7] Lint

---
 modules/text_generation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 70f03443..16aba3cb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -483,7 +483,6 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         t1 = time.time()
         original_tokens = len(encode(original_question)[0])
         new_tokens = len(encode(original_question + reply)[0]) - original_tokens
-
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 
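Note on patches 6 and 7: patch 6 restores the two encode() calls for the post-generation stats and removes the per-chunk counters from patch 3, and patch 7 only drops the blank line patch 3 had introduced before the print. With the llama.cpp loader, encode() resolves to the /tokenize endpoint seen in patch 5, so the counting amounts to two small requests against the local server; the commit title calls them harmless now, presumably because each is a single quick round trip to 127.0.0.1. A rough sketch of that counting path, built only from the endpoint and fields visible in the patches (PORT is a placeholder):

    import requests

    PORT = 8080  # placeholder; substitute the port the server is listening on

    def count_tokens(text):
        url = f"http://127.0.0.1:{PORT}/tokenize"
        result = requests.post(url, json={"content": text}).json()
        return len(result.get("tokens", []))

    prompt = "Once upon a time"
    reply = " there was a llama."
    context_tokens = count_tokens(prompt)
    new_tokens = count_tokens(prompt + reply) - context_tokens
    print(f"context {context_tokens}, new tokens {new_tokens}")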