llama.cpp: optimize the llama-server connection

This commit is contained in:
oobabooga 2025-04-18 18:46:36 -07:00
parent 2002590536
commit b3342b8dd8

View file

@ -24,6 +24,7 @@ class LlamaServer:
self.server_path = server_path
self.port = self._find_available_port()
self.process = None
self.session = requests.Session()
self.vocabulary_size = None
self.bos_token = "<s>"
@ -40,7 +41,7 @@ class LlamaServer:
"add_special": add_bos_token,
}
response = requests.post(url, json=payload)
response = self.session.post(url, json=payload)
result = response.json()
return result.get("tokens", [])
@ -50,7 +51,7 @@ class LlamaServer:
"tokens": token_ids,
}
response = requests.post(url, json=payload)
response = self.session.post(url, json=payload)
result = response.json()
return result.get("content", "")
@ -140,7 +141,7 @@ class LlamaServer:
print()
# Make a direct request with streaming enabled
response = requests.post(url, json=payload, stream=True)
response = self.session.post(url, json=payload, stream=True)
response.raise_for_status() # Raise an exception for HTTP errors
full_text = ""
@ -203,7 +204,7 @@ class LlamaServer:
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
print()
response = requests.post(url, json=payload)
response = self.session.post(url, json=payload)
result = response.json()
if "completion_probabilities" in result:
@ -217,7 +218,7 @@ class LlamaServer:
def _get_vocabulary_size(self):
"""Get and store the model's vocabulary size."""
url = f"http://localhost:{self.port}/v1/models"
response = requests.get(url).json()
response = self.session.get(url).json()
if "data" in response and len(response["data"]) > 0:
model_info = response["data"][0]
@ -227,7 +228,7 @@ class LlamaServer:
def _get_bos_token(self):
"""Get and store the model's BOS token."""
url = f"http://localhost:{self.port}/props"
response = requests.get(url).json()
response = self.session.get(url).json()
if "bos_token" in response:
self.bos_token = response["bos_token"]
@ -309,7 +310,7 @@ class LlamaServer:
raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}")
try:
response = requests.get(health_url)
response = self.session.get(health_url)
if response.status_code == 200:
break
except: