EXL2: add another torch.cuda.synchronize() call to prevent errors

This commit is contained in:
oobabooga 2025-04-24 09:03:49 -07:00
parent b313adf653
commit f1b64df8dd

View file

@ -264,6 +264,11 @@ def apply_stopping_strings(reply, all_stop_strings):
def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
import torch
if torch.cuda.is_available():
torch.cuda.synchronize()
reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True)
# Handle tokenizers that do not add the leading space for the first token