diff --git a/tortoise/models/diffusion_decoder.py b/tortoise/models/diffusion_decoder.py index 9d66e87..f67d21a 100644 --- a/tortoise/models/diffusion_decoder.py +++ b/tortoise/models/diffusion_decoder.py @@ -5,7 +5,7 @@ from abc import abstractmethod import torch import torch.nn as nn import torch.nn.functional as F -from torch.cuda.amp import autocast +from torch import autocast from tortoise.models.arch_util import normalization, AttentionBlock