Add StreamingLLM for llamacpp & llamacpp_HF (2nd attempt) (#5669)

oobabooga 2024-03-09 00:25:33 -03:00 committed by GitHub
parent 9271e80914
commit afb51bd5d6
7 changed files with 147 additions and 0 deletions


@@ -2,6 +2,9 @@ from typing import Sequence
from tqdm import tqdm

from modules import shared
from modules.cache_utils import process_llamacpp_cache

try:
    import llama_cpp
except:
@@ -58,6 +61,25 @@ def eval_with_progress(self, tokens: Sequence[int]):
        self.n_tokens += n_tokens


def monkey_patch_generate(lib):

    def my_generate(self, *args, **kwargs):

        if shared.args.streaming_llm:
            new_sequence = args[0]
            past_sequence = self._input_ids

            # Do the cache trimming for StreamingLLM
            process_llamacpp_cache(self, new_sequence, past_sequence)

        for output in self.original_generate(*args, **kwargs):
            yield output

    lib.Llama.original_generate = lib.Llama.generate
    lib.Llama.generate = my_generate


for lib in [llama_cpp, llama_cpp_cuda, llama_cpp_cuda_tensorcores]:
    if lib is not None:
        lib.Llama.eval = eval_with_progress
        monkey_patch_generate(lib)
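
For context: process_llamacpp_cache (imported from modules.cache_utils in the first hunk) runs before the wrapped generate call, applying StreamingLLM-style trimming so that when old chat history is dropped from the start of the prompt, the llama.cpp KV cache can be reused rather than re-evaluated from scratch. Below is a minimal sketch of the eviction idea only, not the actual code in modules/cache_utils.py; the function name and the sink/window sizes are illustrative assumptions.

# Illustrative sketch of StreamingLLM eviction, NOT modules/cache_utils.py.
# The StreamingLLM idea: keep a few initial "attention sink" tokens plus the
# most recent tokens, and drop the middle of the sequence when it overflows.
def streaming_llm_trim(cached_tokens, sink_tokens=4, window=2048):
    if len(cached_tokens) <= sink_tokens + window:
        return cached_tokens
    return cached_tokens[:sink_tokens] + cached_tokens[-window:]

In the patch loop above, llama_cpp_cuda and llama_cpp_cuda_tensorcores are presumably imported with the same try/except guard shown for llama_cpp in the first hunk, falling back to None when those builds are unavailable, which is why the loop skips None entries before patching.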