llama.cpp: avoid two 'encode' calls

This commit is contained in:
oobabooga 2025-04-19 16:35:01 -07:00
parent ed42154c78
commit ba976d1390
2 changed files with 34 additions and 22 deletions

View file

@ -27,6 +27,8 @@ class LlamaServer:
self.session = requests.Session()
self.vocabulary_size = None
self.bos_token = "<s>"
self.last_input_length = 0
self.last_output_length = 0
# Start the server
self._start_server()
@ -140,6 +142,9 @@ class LlamaServer:
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
print()
self.last_input_length = len(token_ids)
self.last_output_length = 0
# Make a direct request with streaming enabled using a context manager
with self.session.post(url, json=payload, stream=True) as response:
response.raise_for_status() # Raise an exception for HTTP errors
@ -151,30 +156,32 @@ class LlamaServer:
if shared.stop_everything:
break
if line:
try:
# Check if the line starts with "data: " and remove it
if line.startswith('data: '):
line = line[6:] # Remove the "data: " prefix
if not line:
continue
# Parse the JSON data
data = json.loads(line)
try:
# Check if the line starts with "data: " and remove it
if line.startswith('data: '):
line = line[6:] # Remove the "data: " prefix
# Extract the token content
if 'content' in data:
token_text = data['content']
full_text += token_text
yield full_text
# Parse the JSON data
data = json.loads(line)
# Check if generation is complete
if data.get('stop', False):
break
# Extract the token content
if data.get('content', ''):
full_text += data['content']
self.last_output_length += 1
yield full_text
except json.JSONDecodeError as e:
# Log the error and the problematic line
print(f"JSON decode error: {e}")
print(f"Problematic line: {line}")
continue
# Check if generation is complete
if data.get('stop', False):
break
except json.JSONDecodeError as e:
# Log the error and the problematic line
print(f"JSON decode error: {e}")
print(f"Problematic line: {line}")
continue
def generate(self, prompt, state):
output = ""

View file

@ -481,8 +481,13 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(original_question + reply)[0]) - original_tokens
if shared.args.loader == 'llama.cpp':
original_tokens = shared.model.last_input_length
new_tokens = shared.model.last_output_length
else:
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(original_question + reply)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return