llama.cpp: new optimization attempt

This commit is contained in:
oobabooga 2025-04-18 21:04:56 -07:00
parent e2e90af6cd
commit e2e73ed22f

View file

@ -10,6 +10,7 @@ import llama_cpp_binaries
import requests
from modules import shared
from modules.callbacks import Iteratorize
from modules.logging_colors import logger
llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"}
@ -118,7 +119,7 @@ class LlamaServer:
return payload
def generate_with_streaming(self, prompt, state):
def generate(self, prompt, state, callback=None):
url = f"http://localhost:{self.port}/completion"
payload = self.prepare_payload(state)
@ -144,7 +145,7 @@ class LlamaServer:
with self.session.post(url, json=payload, stream=True) as response:
response.raise_for_status() # Raise an exception for HTTP errors
full_text = ""
output = ""
# Process the streaming response
for line in response.iter_lines(decode_unicode=True):
@ -162,9 +163,10 @@ class LlamaServer:
# Extract the token content
if 'content' in data:
token_text = data['content']
full_text += token_text
yield full_text
text = data['content']
output += text
if callback:
callback(output)
# Check if generation is complete
if data.get('stop', False):
@ -176,13 +178,13 @@ class LlamaServer:
print(f"Problematic line: {line}")
continue
def generate(self, prompt, state):
output = ""
for output in self.generate_with_streaming(prompt, state):
pass
return output
def generate_with_streaming(self, *args, **kwargs):
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
for output in generator:
yield output
def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
"""Get the logits/probabilities for the next token after a prompt"""
url = f"http://localhost:{self.port}/completion"