llama.cpp: close the connection properly when generation is cancelled

oobabooga 2025-04-18 19:01:39 -07:00
parent b3342b8dd8
commit f727b4a2cc


@@ -140,42 +140,42 @@ class LlamaServer:
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
 
-        # Make a direct request with streaming enabled
-        response = self.session.post(url, json=payload, stream=True)
-        response.raise_for_status()  # Raise an exception for HTTP errors
+        # Make a direct request with streaming enabled using a context manager
+        with self.session.post(url, json=payload, stream=True) as response:
+            response.raise_for_status()  # Raise an exception for HTTP errors
 
-        full_text = ""
+            full_text = ""
 
-        # Process the streaming response
-        for line in response.iter_lines():
-            if shared.stop_everything:
-                break
+            # Process the streaming response
+            for line in response.iter_lines():
+                if shared.stop_everything:
+                    break
 
-            if line:
-                try:
-                    # Check if the line starts with "data: " and remove it
-                    line_str = line.decode('utf-8')
-                    if line_str.startswith('data: '):
-                        line_str = line_str[6:]  # Remove the "data: " prefix
+                if line:
+                    try:
+                        # Check if the line starts with "data: " and remove it
+                        line_str = line.decode('utf-8')
+                        if line_str.startswith('data: '):
+                            line_str = line_str[6:]  # Remove the "data: " prefix
 
-                    # Parse the JSON data
-                    data = json.loads(line_str)
+                        # Parse the JSON data
+                        data = json.loads(line_str)
 
-                    # Extract the token content
-                    if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                        # Extract the token content
+                        if 'content' in data:
+                            token_text = data['content']
+                            full_text += token_text
+                            yield full_text
 
-                    # Check if generation is complete
-                    if data.get('stop', False):
-                        break
-                except json.JSONDecodeError as e:
-                    # Log the error and the problematic line
-                    print(f"JSON decode error: {e}")
-                    print(f"Problematic line: {line}")
-                    continue
+                        # Check if generation is complete
+                        if data.get('stop', False):
+                            break
+                    except json.JSONDecodeError as e:
+                        # Log the error and the problematic line
+                        print(f"JSON decode error: {e}")
+                        print(f"Problematic line: {line}")
+                        continue
 
     def generate(self, prompt, state):
         output = ""
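The point of the change: the old code never closed the response when the caller stopped consuming the generator, so a cancelled generation could leave the HTTP connection to the llama.cpp server open until garbage collection. With the `with` block, abandoning the generator raises `GeneratorExit` at the `yield`, the block unwinds, and `response.close()` runs immediately, which also lets the server stop generating. A minimal sketch of the pattern, with a hypothetical helper name and endpoint (not from the commit):

```python
import requests

def stream_lines(url, payload):
    # Hypothetical helper mirroring the pattern above: the `with` block
    # guarantees response.close() runs even if the caller abandons us.
    with requests.post(url, json=payload, stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                yield line.decode('utf-8')
        # If the consumer calls gen.close(), GeneratorExit is raised at the
        # yield above, the `with` block unwinds, and the connection is freed.

# Hypothetical usage: cancel after the first streamed chunk.
gen = stream_lines('http://127.0.0.1:8080/completion',
                   {'prompt': 'Hi', 'stream': True})
print(next(gen))  # consume one line
gen.close()       # cancellation: the HTTP connection closes immediately
```

The same mechanism covers the `shared.stop_everything` path in the diff: `break` exits the loop, the `with` block closes the connection, and the server sees the client disconnect instead of streaming into a dead socket.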