mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-10 17:13:39 +00:00
Add StreamingLLM for llamacpp & llamacpp_HF (2nd attempt) (#5669)
This commit is contained in:
parent
9271e80914
commit
afb51bd5d6
7 changed files with 147 additions and 0 deletions
|
|
@ -2,6 +2,9 @@ from typing import Sequence
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared
|
||||
from modules.cache_utils import process_llamacpp_cache
|
||||
|
||||
try:
|
||||
import llama_cpp
|
||||
except:
|
||||
|
|
@ -58,6 +61,25 @@ def eval_with_progress(self, tokens: Sequence[int]):
|
|||
self.n_tokens += n_tokens
|
||||
|
||||
|
||||
def monkey_patch_generate(lib):
|
||||
|
||||
def my_generate(self, *args, **kwargs):
|
||||
|
||||
if shared.args.streaming_llm:
|
||||
new_sequence = args[0]
|
||||
past_sequence = self._input_ids
|
||||
|
||||
# Do the cache trimming for StreamingLLM
|
||||
process_llamacpp_cache(self, new_sequence, past_sequence)
|
||||
|
||||
for output in self.original_generate(*args, **kwargs):
|
||||
yield output
|
||||
|
||||
lib.Llama.original_generate = lib.Llama.generate
|
||||
lib.Llama.generate = my_generate
|
||||
|
||||
|
||||
for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
|
||||
if lib is not None:
|
||||
lib.Llama.eval = eval_with_progress
|
||||
monkey_patch_generate(lib)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue