diff --git a/modules/torch_utils.py b/modules/torch_utils.py index 1bc82f03..418520a8 100644 --- a/modules/torch_utils.py +++ b/modules/torch_utils.py @@ -2,12 +2,27 @@ import gc import torch from accelerate.utils import is_npu_available, is_xpu_available +from transformers import is_torch_npu_available, is_torch_xpu_available from modules import shared def get_device(): - return getattr(shared.model, 'device', None) + if hasattr(shared.model, 'device'): + return shared.model.device + elif torch.cuda.is_available(): + return torch.device('cuda') + elif shared.args.deepspeed: + import deepspeed + return deepspeed.get_accelerator().current_device_name() + elif torch.backends.mps.is_available(): + return torch.device('mps') + elif is_torch_xpu_available(): + return torch.device('xpu:0') + elif is_torch_npu_available(): + return torch.device('npu:0') + else: + return None def clear_torch_cache():