diff --git a/modules/torch_utils.py b/modules/torch_utils.py
index ad9b26ad..1bc82f03 100644
--- a/modules/torch_utils.py
+++ b/modules/torch_utils.py
@@ -2,25 +2,12 @@ import gc
 
 import torch
 from accelerate.utils import is_npu_available, is_xpu_available
-from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import shared
 
 
 def get_device():
-    if torch.cuda.is_available():
-        return torch.device('cuda')
-    elif shared.args.deepspeed:
-        import deepspeed
-        return deepspeed.get_accelerator().current_device_name()
-    elif torch.backends.mps.is_available():
-        return torch.device('mps')
-    elif is_torch_xpu_available():
-        return torch.device('xpu:0')
-    elif is_torch_npu_available():
-        return torch.device('npu:0')
-    else:
-        return None
+    return getattr(shared.model, 'device', None)
 
 
 def clear_torch_cache():