mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-10 17:13:39 +00:00
Revert the llama-cpp-python update
This commit is contained in:
parent
a687f950ba
commit
0f53a736c1
13 changed files with 109 additions and 45 deletions
|
|
@ -1,9 +1,13 @@
|
|||
import importlib
|
||||
import platform
|
||||
from typing import Sequence
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.cache_utils import process_llamacpp_cache
|
||||
|
||||
|
||||
imported_module = None
|
||||
|
||||
|
||||
|
|
@ -21,6 +25,7 @@ def llama_cpp_lib():
|
|||
else:
|
||||
lib_names = [
|
||||
('cpu', 'llama_cpp'),
|
||||
('tensorcores', 'llama_cpp_cuda_tensorcores'),
|
||||
(None, 'llama_cpp_cuda'),
|
||||
(None, 'llama_cpp')
|
||||
]
|
||||
|
|
@ -44,6 +49,48 @@ def llama_cpp_lib():
|
|||
return None
|
||||
|
||||
|
||||
def eval_with_progress(self, tokens: Sequence[int]):
|
||||
"""
|
||||
A copy of
|
||||
|
||||
https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
|
||||
|
||||
with tqdm to show prompt processing progress.
|
||||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
assert self._batch.batch is not None
|
||||
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||
|
||||
if len(tokens) > 1:
|
||||
progress_bar = tqdm(range(0, len(tokens), self.n_batch), desc="Prompt evaluation", leave=False)
|
||||
else:
|
||||
progress_bar = range(0, len(tokens), self.n_batch)
|
||||
|
||||
for i in progress_bar:
|
||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||
n_past = self.n_tokens
|
||||
n_tokens = len(batch)
|
||||
self._batch.set_batch(
|
||||
batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
|
||||
)
|
||||
self._ctx.decode(self._batch)
|
||||
# Save tokens
|
||||
self.input_ids[n_past : n_past + n_tokens] = batch
|
||||
# Save logits
|
||||
if self.context_params.logits_all:
|
||||
rows = n_tokens
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
|
||||
else:
|
||||
rows = 1
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
|
||||
# Update n_tokens
|
||||
self.n_tokens += n_tokens
|
||||
|
||||
|
||||
def monkey_patch_llama_cpp_python(lib):
|
||||
if getattr(lib.Llama, '_is_patched', False):
|
||||
# If the patch is already applied, do nothing
|
||||
|
|
@ -60,6 +107,7 @@ def monkey_patch_llama_cpp_python(lib):
|
|||
for output in self.original_generate(*args, **kwargs):
|
||||
yield output
|
||||
|
||||
lib.Llama.eval = eval_with_progress
|
||||
lib.Llama.original_generate = lib.Llama.generate
|
||||
lib.Llama.generate = my_generate
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue