diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py index 3bf44c9b..0f742fa2 100644 --- a/modules/exllamav3_hf.py +++ b/modules/exllamav3_hf.py @@ -153,6 +153,9 @@ class Exllamav3HF(PreTrainedModel): else: self.past_seq = seq_tensor + if torch.cuda.is_available(): + torch.cuda.synchronize() + loss = None if labels is not None: # Shift so that tokens < n predict n