From 218dc01b5189b2dcf08045bf0a8f17e4fad176a2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 9 Oct 2025 10:59:34 -0700 Subject: [PATCH] Add fallbacks after 93aa7b3ed3f770c25e21594bafe41bb2dded2caa --- modules/torch_utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/modules/torch_utils.py b/modules/torch_utils.py index 1bc82f03..418520a8 100644 --- a/modules/torch_utils.py +++ b/modules/torch_utils.py @@ -2,12 +2,27 @@ import gc import torch from accelerate.utils import is_npu_available, is_xpu_available +from transformers import is_torch_npu_available, is_torch_xpu_available from modules import shared def get_device(): - return getattr(shared.model, 'device', None) + if hasattr(shared.model, 'device'): + return shared.model.device + elif torch.cuda.is_available(): + return torch.device('cuda') + elif shared.args.deepspeed: + import deepspeed + return deepspeed.get_accelerator().current_device_name() + elif torch.backends.mps.is_available(): + return torch.device('mps') + elif is_torch_xpu_available(): + return torch.device('xpu:0') + elif is_torch_npu_available(): + return torch.device('npu:0') + else: + return None def clear_torch_cache():