From b3342b8dd8eb3f3c191ef0706ba2ab1493fa8937 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 18 Apr 2025 18:46:36 -0700
Subject: [PATCH] llama.cpp: optimize the llama-server connection

---
 modules/llama_cpp_server.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 26ab8f10..822800b9 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -24,6 +24,7 @@ class LlamaServer:
         self.server_path = server_path
         self.port = self._find_available_port()
         self.process = None
+        self.session = requests.Session()

         self.vocabulary_size = None
         self.bos_token = ""
@@ -40,7 +41,7 @@ class LlamaServer:
             "add_special": add_bos_token,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("tokens", [])

@@ -50,7 +51,7 @@ class LlamaServer:
             "tokens": token_ids,
         }

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()
         return result.get("content", "")

@@ -140,7 +141,7 @@ class LlamaServer:
             print()

         # Make a direct request with streaming enabled
-        response = requests.post(url, json=payload, stream=True)
+        response = self.session.post(url, json=payload, stream=True)
         response.raise_for_status()  # Raise an exception for HTTP errors

         full_text = ""
@@ -203,7 +204,7 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()

-        response = requests.post(url, json=payload)
+        response = self.session.post(url, json=payload)
         result = response.json()

         if "completion_probabilities" in result:
@@ -217,7 +218,7 @@ class LlamaServer:
     def _get_vocabulary_size(self):
         """Get and store the model's maximum context length."""
         url = f"http://localhost:{self.port}/v1/models"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()

         if "data" in response and len(response["data"]) > 0:
             model_info = response["data"][0]
@@ -227,7 +228,7 @@ class LlamaServer:
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
         url = f"http://localhost:{self.port}/props"
-        response = requests.get(url).json()
+        response = self.session.get(url).json()

         if "bos_token" in response:
             self.bos_token = response["bos_token"]
@@ -309,7 +310,7 @@ class LlamaServer:
                 raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")

             try:
-                response = requests.get(health_url)
+                response = self.session.get(health_url)
                 if response.status_code == 200:
                     break
             except:
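
Note: the patch above replaces per-call requests.post()/requests.get() with a single
requests.Session owned by LlamaServer. A Session pools connections through urllib3,
so successive requests to the same llama-server host reuse one TCP connection
(HTTP keep-alive) instead of opening and tearing down a new socket per call. Below
is a minimal standalone sketch of the before/after pattern, not part of the patch;
it assumes a hypothetical server listening on localhost:8080 purely for illustration.

    import requests

    # Hypothetical endpoint, assumed for this sketch only.
    URL = "http://localhost:8080/health"

    # Before: each module-level call opens its own connection.
    for _ in range(3):
        requests.get(URL)

    # After: one Session keeps a pooled connection alive across calls.
    session = requests.Session()
    for _ in range(3):
        session.get(URL)
    session.close()

The saving per request is the TCP handshake and socket setup, which is small in
isolation but adds up across the many short calls the class makes (/tokenize,
/detokenize, /completion, /health) over the lifetime of one server process.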