From 93aa7b3ed3f770c25e21594bafe41bb2dded2caa Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 9 Oct 2025 08:49:27 -0700
Subject: [PATCH] Better handle multigpu setups with transformers +
 bitsandbytes

---
 modules/torch_utils.py | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/modules/torch_utils.py b/modules/torch_utils.py
index ad9b26ad..1bc82f03 100644
--- a/modules/torch_utils.py
+++ b/modules/torch_utils.py
@@ -2,25 +2,12 @@ import gc
 
 import torch
 from accelerate.utils import is_npu_available, is_xpu_available
-from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import shared
 
 
 def get_device():
-    if torch.cuda.is_available():
-        return torch.device('cuda')
-    elif shared.args.deepspeed:
-        import deepspeed
-        return deepspeed.get_accelerator().current_device_name()
-    elif torch.backends.mps.is_available():
-        return torch.device('mps')
-    elif is_torch_xpu_available():
-        return torch.device('xpu:0')
-    elif is_torch_npu_available():
-        return torch.device('npu:0')
-    else:
-        return None
+    return getattr(shared.model, 'device', None)
 
 
 def clear_torch_cache():
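
Note (illustrative sketch, not part of the patch): the new get_device() defers to the loaded model instead of probing backends. With transformers + bitsandbytes on a multi-GPU machine, accelerate's device_map="auto" may shard the model across GPUs, and model.device reports where the first weights live (e.g. cuda:0), which is where input tensors must go; the old blanket torch.device('cuda') resolves to the current device, which may not match. A minimal standalone illustration, assuming a placeholder checkpoint id (shared.model in the repo plays the role of `model` here):

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "facebook/opt-1.3b"  # placeholder; any causal LM works
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # accelerate may shard layers across several GPUs
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # bitsandbytes
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# For a sharded model, model.device reports where the first parameters
# live (e.g. cuda:0); input tensors belong on that device, mirroring
# what the patched get_device() now returns.
device = getattr(model, 'device', None)
inputs = tokenizer("Hello", return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The getattr fallback also means get_device() simply returns None when no model is loaded yet, rather than guessing a backend.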