Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-12-06 07:12:10 +01:00

llama.cpp: new optimization attempt

Commit: e2e73ed22f
Parent: e2e90af6cd
@@ -10,6 +10,7 @@ import llama_cpp_binaries
 import requests

 from modules import shared
+from modules.callbacks import Iteratorize
 from modules.logging_colors import logger

 llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"}
@@ -118,7 +119,7 @@ class LlamaServer:
         return payload

-    def generate_with_streaming(self, prompt, state):
+    def generate(self, prompt, state, callback=None):
         url = f"http://localhost:{self.port}/completion"
         payload = self.prepare_payload(state)

@@ -144,7 +145,7 @@ class LlamaServer:
         with self.session.post(url, json=payload, stream=True) as response:
             response.raise_for_status()  # Raise an exception for HTTP errors

-            full_text = ""
+            output = ""

             # Process the streaming response
             for line in response.iter_lines(decode_unicode=True):
@@ -162,9 +163,10 @@ class LlamaServer:

                     # Extract the token content
                     if 'content' in data:
-                        token_text = data['content']
-                        full_text += token_text
-                        yield full_text
+                        text = data['content']
+                        output += text
+                        if callback:
+                            callback(output)

                     # Check if generation is complete
                     if data.get('stop', False):
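The hunk above replaces the direct `yield full_text` with an optional callback that receives the accumulated text each time new content arrives. A minimal, self-contained sketch of that contract follows; `fake_generate` is a hypothetical stand-in for the modified `generate` method, not code from this commit:

# Hypothetical stand-in for the modified LlamaServer.generate(): it
# accumulates the streamed tokens and reports the running text through
# the optional callback instead of yielding partial results.
def fake_generate(prompt, state, callback=None):
    output = ""
    for token in ("Hello", ",", " world"):
        output += token
        if callback:
            callback(output)
    return output


chunks = []
final = fake_generate("hi", {}, callback=chunks.append)
assert final == "Hello, world"
assert chunks == ["Hello", "Hello,", "Hello, world"]

Note that the callback gets the full text generated so far rather than just the newest token, which keeps the caller's view identical to the old `yield full_text` behaviour.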
@@ -176,13 +178,13 @@ class LlamaServer:
                     print(f"Problematic line: {line}")
                     continue

-    def generate(self, prompt, state):
-        output = ""
-        for output in self.generate_with_streaming(prompt, state):
-            pass
-
         return output

+    def generate_with_streaming(self, *args, **kwargs):
+        with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
+            for output in generator:
+                yield output
+
     def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
         """Get the logits/probabilities for the next token after a prompt"""
         url = f"http://localhost:{self.port}/completion"
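In the final hunk, the blocking `generate` wrapper is removed and `generate_with_streaming` becomes a thin adapter that runs the callback-based `generate` through `Iteratorize` and yields whatever the callback reports. As a rough illustration of that pattern, here is a simplified callback-to-generator adapter built from a background thread and a queue; it is an assumption-laden sketch, not the webui's actual `Iteratorize` from `modules.callbacks`, which also handles early termination and stop signals:

import threading
from queue import Queue


class CallbackToGenerator:
    """Run a callback-based function in a background thread and expose
    the values it reports through its callback as a generator."""

    _SENTINEL = object()  # marks the end of the stream

    def __init__(self, func, args=None, kwargs=None):
        self.queue = Queue()

        def _callback(value):
            # Every partial output reported by func becomes one yielded item.
            self.queue.put(value)

        def _run():
            try:
                func(*(args or ()), callback=_callback, **(kwargs or {}))
            finally:
                self.queue.put(self._SENTINEL)

        self.thread = threading.Thread(target=_run, daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is self._SENTINEL:
            raise StopIteration
        return item

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.thread.join(timeout=5)
        return False


def demo_generate(prompt, state, callback=None):
    # Hypothetical callback-based producer, standing in for LlamaServer.generate.
    for partial in ("a", "ab", "abc"):
        if callback:
            callback(partial)
    return "abc"


with CallbackToGenerator(demo_generate, ("prompt", {})) as gen:
    assert list(gen) == ["a", "ab", "abc"]

Usage would mirror the new method: iterate the adapter wrapped around the callback-based `generate` and each partial output arrives as a generator item (the class and demo names here are hypothetical).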