Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-12-06 07:12:10 +01:00
llama.cpp: new optimization attempt
parent e2e90af6cd
commit e2e73ed22f
@@ -10,6 +10,7 @@ import llama_cpp_binaries
 import requests
 
 from modules import shared
+from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
 
 llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"}
@@ -118,7 +119,7 @@ class LlamaServer:
 
         return payload
 
-    def generate_with_streaming(self, prompt, state):
+    def generate(self, prompt, state, callback=None):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)
 
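The renamed method no longer yields partial results; it drives the whole request itself and reports progress through the optional callback parameter. A minimal caller sketch, assuming a constructed LlamaServer instance named server plus prompt and state values (all hypothetical names, not shown in this diff):

    # Hypothetical usage of the new callback-style generate().
    partials = []

    def on_partial(output):
        # Called with the accumulated output after each streamed token.
        partials.append(output)

    server.generate(prompt, state, callback=on_partial)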
@@ -144,7 +145,7 @@ class LlamaServer:
         with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors
 
-            full_text = ""
+            output = ""
 
             # Process the streaming response
             for line in response.iter_lines(decode_unicode=True):
@@ -162,9 +163,10 @@ class LlamaServer:
 
                 # Extract the token content
                 if 'content' in data:
-                    token_text = data['content']
-                    full_text += token_text
-                    yield full_text
+                    text = data['content']
+                    output += text
+                    if callback:
+                        callback(output)
 
                 # Check if generation is complete
                 if data.get('stop', False):
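Note that the callback receives the entire accumulated output on every token, not just the newest fragment, so a consumer that wants only the increment has to diff against what it has already seen. An illustrative sketch (the names here are hypothetical):

    seen = ""

    def on_partial(output):
        # Extract only the newly generated text from the accumulated output.
        global seen
        delta = output[len(seen):]
        seen = output
        print(delta, end="", flush=True)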
@@ -176,13 +178,13 @@ class LlamaServer:
                     print(f"Problematic line: {line}")
                     continue
 
-    def generate(self, prompt, state):
-        output = ""
-        for output in self.generate_with_streaming(prompt, state):
-            pass
-
-        return output
+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
+            for output in generator:
+                yield output
 
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
        """Get the logits/probabilities for the next token after a prompt"""
        url = f"http://localhost:{self.port}/completion"
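With generate() now callback-based, generate_with_streaming() becomes a thin wrapper: Iteratorize (imported from modules.callbacks in the first hunk) inverts the callback interface back into a generator. The diff does not show callbacks.py itself, but the usual shape of such an adapter is a background thread feeding a queue; the following is a self-contained sketch of that pattern under that assumption, not the repository's actual implementation:

    import threading
    from queue import Queue

    class CallbackToGenerator:
        """Sketch of the Iteratorize idea: run a callback-style function in a
        background thread and iterate over its partial results."""

        _SENTINEL = object()

        def __init__(self, func, args=None, kwargs=None):
            self.func = func
            self.args = args or []
            self.kwargs = kwargs or {}
            self.queue = Queue()

        def __enter__(self):
            def producer():
                try:
                    # The wrapped function reports progress via callback=...
                    self.func(*self.args, callback=self.queue.put, **self.kwargs)
                finally:
                    self.queue.put(self._SENTINEL)  # always unblock the consumer

            threading.Thread(target=producer, daemon=True).start()
            # iter(callable, sentinel) pulls from the queue until the sentinel.
            return iter(self.queue.get, self._SENTINEL)

        def __exit__(self, exc_type, exc, tb):
            return False

    # Demo with a stand-in for the callback-style generate():
    def fake_generate(prompt, callback=None):
        output = ""
        for token in ["Hello", ",", " world"]:
            output += token
            if callback:
                callback(output)

    with CallbackToGenerator(fake_generate, args=["hi"]) as generator:
        for partial in generator:
            print(partial)

Wrapping it this way keeps a single generation code path serving both the blocking and the streaming callers.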