diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index 5071c40c..02a56b3c 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -35,7 +35,7 @@ class LlamaServer:
         if self.bos_token and text.startswith(self.bos_token):
             add_bos_token = False
 
-        url = f"http://localhost:{self.port}/tokenize"
+        url = f"http://127.0.0.1:{self.port}/tokenize"
         payload = {
             "content": text,
             "add_special": add_bos_token,
@@ -46,7 +46,7 @@ class LlamaServer:
         return result.get("tokens", [])
 
     def decode(self, token_ids, **kwargs):
-        url = f"http://localhost:{self.port}/detokenize"
+        url = f"http://127.0.0.1:{self.port}/detokenize"
         payload = {
             "tokens": token_ids,
         }
@@ -119,7 +119,7 @@ class LlamaServer:
         return payload
 
     def generate_with_streaming(self, prompt, state):
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
         payload = self.prepare_payload(state)
 
         token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
@@ -147,34 +147,37 @@ class LlamaServer:
         full_text = ""
 
         # Process the streaming response
-        for line in response.iter_lines(decode_unicode=True):
+        for line in response.iter_lines():
             if shared.stop_everything:
                 break
 
-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    if line.startswith('data: '):
-                        line = line[6:]  # Remove the "data: " prefix
+            if not line:
+                continue
 
-                    # Parse the JSON data
-                    data = json.loads(line)
+            try:
+                line = line.decode('utf-8')
 
-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                # Check if the line starts with "data: " and remove it
+                if line.startswith('data: '):
+                    line = line[6:]  # Remove the "data: " prefix
 
-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
+                # Parse the JSON data
+                data = json.loads(line)
 
-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+                # Extract the token content
+                if data.get('content', ''):
+                    full_text += data['content']
+                    yield full_text
+
+                # Check if generation is complete
+                if data.get('stop', False):
+                    break
+
+            except json.JSONDecodeError as e:
+                # Log the error and the problematic line
+                print(f"JSON decode error: {e}")
+                print(f"Problematic line: {line}")
+                continue
 
     def generate(self, prompt, state):
         output = ""
@@ -185,7 +188,7 @@ class LlamaServer:
 
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
-        url = f"http://localhost:{self.port}/completion"
+        url = f"http://127.0.0.1:{self.port}/completion"
 
         payload = self.prepare_payload(state)
         payload.update({
@@ -216,7 +219,7 @@ class LlamaServer:
 
     def _get_vocabulary_size(self):
         """Get and store the model's maximum context length."""
-        url = f"http://localhost:{self.port}/v1/models"
+        url = f"http://127.0.0.1:{self.port}/v1/models"
         response = self.session.get(url).json()
 
         if "data" in response and len(response["data"]) > 0:
@@ -226,7 +229,7 @@ class LlamaServer:
 
     def _get_bos_token(self):
         """Get and store the model's BOS token."""
-        url = f"http://localhost:{self.port}/props"
+        url = f"http://127.0.0.1:{self.port}/props"
         response = self.session.get(url).json()
         if "bos_token" in response:
             self.bos_token = response["bos_token"]
@@ -299,7 +302,7 @@ class LlamaServer:
         threading.Thread(target=filter_stderr, args=(self.process.stderr,), daemon=True).start()
 
         # Wait for server to be healthy
-        health_url = f"http://localhost:{self.port}/health"
+        health_url = f"http://127.0.0.1:{self.port}/health"
         start_time = time.time()
         timeout = 3600 * 8  # 8 hours
         while time.time() - start_time < timeout: