diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 822800b9..34aab613 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -140,42 +140,42 @@ class LlamaServer: pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload) print() - # Make a direct request with streaming enabled - response = self.session.post(url, json=payload, stream=True) - response.raise_for_status() # Raise an exception for HTTP errors + # Make a direct request with streaming enabled using a context manager + with self.session.post(url, json=payload, stream=True) as response: + response.raise_for_status() # Raise an exception for HTTP errors - full_text = "" + full_text = "" - # Process the streaming response - for line in response.iter_lines(): - if shared.stop_everything: - break + # Process the streaming response + for line in response.iter_lines(): + if shared.stop_everything: + break - if line: - try: - # Check if the line starts with "data: " and remove it - line_str = line.decode('utf-8') - if line_str.startswith('data: '): - line_str = line_str[6:] # Remove the "data: " prefix + if line: + try: + # Check if the line starts with "data: " and remove it + line_str = line.decode('utf-8') + if line_str.startswith('data: '): + line_str = line_str[6:] # Remove the "data: " prefix - # Parse the JSON data - data = json.loads(line_str) + # Parse the JSON data + data = json.loads(line_str) - # Extract the token content - if 'content' in data: - token_text = data['content'] - full_text += token_text - yield full_text + # Extract the token content + if 'content' in data: + token_text = data['content'] + full_text += token_text + yield full_text - # Check if generation is complete - if data.get('stop', False): - break + # Check if generation is complete + if data.get('stop', False): + break - except json.JSONDecodeError as e: - # Log the error and the problematic line - print(f"JSON decode error: {e}") - print(f"Problematic line: {line}") - continue + except json.JSONDecodeError as e: + # Log the error and the problematic line + print(f"JSON decode error: {e}") + print(f"Problematic line: {line}") + continue def generate(self, prompt, state): output = ""