diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index b0fe1154..5071c40c 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -10,7 +10,6 @@ import llama_cpp_binaries import requests from modules import shared -from modules.callbacks import Iteratorize from modules.logging_colors import logger llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"} @@ -119,7 +118,7 @@ class LlamaServer: return payload - def generate(self, prompt, state, callback=None): + def generate_with_streaming(self, prompt, state): url = f"http://localhost:{self.port}/completion" payload = self.prepare_payload(state) @@ -145,7 +144,7 @@ class LlamaServer: with self.session.post(url, json=payload, stream=True) as response: response.raise_for_status() # Raise an exception for HTTP errors - output = "" + full_text = "" # Process the streaming response for line in response.iter_lines(decode_unicode=True): @@ -163,10 +162,9 @@ class LlamaServer: # Extract the token content if 'content' in data: - text = data['content'] - output += text - if callback: - callback(output) + token_text = data['content'] + full_text += token_text + yield full_text # Check if generation is complete if data.get('stop', False): @@ -178,12 +176,12 @@ class LlamaServer: print(f"Problematic line: {line}") continue - return output + def generate(self, prompt, state): + output = "" + for output in self.generate_with_streaming(prompt, state): + pass - def generate_with_streaming(self, *args, **kwargs): - with Iteratorize(self.generate, args, kwargs, callback=None) as generator: - for output in generator: - yield output + return output def get_logits(self, prompt, state, n_probs=128, use_samplers=False): """Get the logits/probabilities for the next token after a prompt"""