Update llama_cpp_python_hijack.py, fix llamacpp_hf

This commit is contained in:
oobabooga 2024-09-30 14:04:21 -07:00
parent 9ca0cd7749
commit 4d9ce586d3
2 changed files with 14 additions and 8 deletions

View file

@ -127,7 +127,7 @@ class LlamacppHF(PreTrainedModel):
self.model.reset()
self.model.eval(seq)
logits = torch.tensor(self.model.scores[self.model.n_tokens - 1, :]).view(1, 1, -1).to(input_ids.device)
logits = torch.tensor(self.model.scores[self.model.last_updated_index, :]).view(1, 1, -1).to(input_ids.device)
else:
self.model.reset()
self.model.eval(seq)
@ -205,5 +205,6 @@ class LlamacppHF(PreTrainedModel):
Llama = llama_cpp_lib().Llama
model = Llama(**params)
model.last_updated_index = -1
return LlamacppHF(model, model_file)