From f34d20922c4d9055dd69940b1bc66b16226d2313 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 29 May 2023 13:31:17 -0300 Subject: [PATCH] Minor fix --- modules/evaluate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/evaluate.py b/modules/evaluate.py index 866d7f90..3283278e 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -82,11 +82,12 @@ def calculate_perplexity(models, input_dataset, stride, _max_length): yield cumulative_log + "Tokenizing the input dataset...\n\n" encodings = encode(text, add_special_tokens=False) seq_len = encodings.shape[1] - if not _max_length: - if hasattr(shared.model.config, 'max_position_embeddings'): - max_length = shared.model.config.max_position_embeddings - else: - max_length = 2048 + if _max_length: + max_length = _max_length + elif hasattr(shared.model.config, 'max_position_embeddings'): + max_length = shared.model.config.max_position_embeddings + else: + max_length = 2048 nlls = [] prev_end_loc = 0