diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index 5071c40c..b0fe1154 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -10,6 +10,7 @@ import llama_cpp_binaries import requests from modules import shared +from modules.callbacks import Iteratorize from modules.logging_colors import logger llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"} @@ -118,7 +119,7 @@ class LlamaServer: return payload - def generate_with_streaming(self, prompt, state): + def generate(self, prompt, state, callback=None): url = f"http://localhost:{self.port}/completion" payload = self.prepare_payload(state) @@ -144,7 +145,7 @@ class LlamaServer: with self.session.post(url, json=payload, stream=True) as response: response.raise_for_status() # Raise an exception for HTTP errors - full_text = "" + output = "" # Process the streaming response for line in response.iter_lines(decode_unicode=True): @@ -162,9 +163,10 @@ class LlamaServer: # Extract the token content if 'content' in data: - token_text = data['content'] - full_text += token_text - yield full_text + text = data['content'] + output += text + if callback: + callback(output) # Check if generation is complete if data.get('stop', False): @@ -176,13 +178,13 @@ class LlamaServer: print(f"Problematic line: {line}") continue - def generate(self, prompt, state): - output = "" - for output in self.generate_with_streaming(prompt, state): - pass - return output + def generate_with_streaming(self, *args, **kwargs): + with Iteratorize(self.generate, args, kwargs, callback=None) as generator: + for output in generator: + yield output + def get_logits(self, prompt, state, n_probs=128, use_samplers=False): """Get the logits/probabilities for the next token after a prompt""" url = f"http://localhost:{self.port}/completion"