This commit is contained in:
oobabooga 2024-06-12 19:00:21 -07:00
parent 2d196ed2fe
commit a36fa73071
2 changed files with 3 additions and 5 deletions

View file

@ -218,7 +218,7 @@ class DRYLogitsProcessor(LogitsProcessor):
match_lengths = {}
for i in match_indices:
next_token = input_ids_row[i+1].item()
next_token = input_ids_row[i + 1].item()
if next_token in self.sequence_breakers:
continue
@ -234,7 +234,7 @@ class DRYLogitsProcessor(LogitsProcessor):
# Start of input reached.
break
previous_token = input_ids_row[-(match_length+1)].item()
previous_token = input_ids_row[-(match_length + 1)].item()
if input_ids_row[j] != previous_token:
# Start of match reached.
break