Update llama_cpp_python_hijack.py, fix llamacpp_hf

This commit is contained in:
oobabooga 2024-09-30 14:04:21 -07:00
parent 9ca0cd7749
commit 4d9ce586d3
2 changed files with 14 additions and 8 deletions

View file

@@ -2,12 +2,12 @@ import importlib
import platform
from typing import Sequence
import numpy as np
from tqdm import tqdm
from modules import shared
from modules.cache_utils import process_llamacpp_cache
imported_module = None
@@ -57,8 +57,6 @@ def eval_with_progress(self, tokens: Sequence[int]):
with tqdm to show prompt processing progress.
"""
assert self._ctx.ctx is not None
assert self._batch.batch is not None
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
if len(tokens) > self.n_batch:
@@ -80,13 +78,20 @@ def eval_with_progress(self, tokens: Sequence[int]):
if self.context_params.logits_all:
rows = n_tokens
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
logits = np.ctypeslib.as_array(
self._ctx.get_logits(), shape=(rows * cols,)
)
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
self.last_updated_index = n_past + n_tokens - 1
else:
rows = 1
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
logits = np.ctypeslib.as_array(
self._ctx.get_logits(), shape=(rows * cols,)
)
last_token_index = min(n_past + n_tokens - 1, self.scores.shape[0] - 1)
self.scores[last_token_index, :] = logits.reshape(-1)
self.last_updated_index = last_token_index
# Update n_tokens
self.n_tokens += n_tokens