Better handle multigpu setups with transformers + bitsandbytes

oobabooga 2025-10-09 08:49:27 -07:00
parent d229dfe991
commit 93aa7b3ed3


@@ -2,25 +2,12 @@ import gc
 
 import torch
 from accelerate.utils import is_npu_available, is_xpu_available
-from transformers import is_torch_npu_available, is_torch_xpu_available
 
 from modules import shared
 
 
 def get_device():
-    if torch.cuda.is_available():
-        return torch.device('cuda')
-    elif shared.args.deepspeed:
-        import deepspeed
-        return deepspeed.get_accelerator().current_device_name()
-    elif torch.backends.mps.is_available():
-        return torch.device('mps')
-    elif is_torch_xpu_available():
-        return torch.device('xpu:0')
-    elif is_torch_npu_available():
-        return torch.device('npu:0')
-    else:
-        return None
+    return getattr(shared.model, 'device', None)
 
 
 def clear_torch_cache():
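
The rewritten get_device() defers to the loaded model itself. With transformers + bitsandbytes and device_map="auto", accelerate can shard a model across several GPUs, so returning torch.device('cuda') (i.e. cuda:0) whenever CUDA is available can place inputs on the wrong device. A minimal sketch of the idea outside the webui codebase follows; the model name and the generate() call are illustrative assumptions, not part of the commit:

# Illustrative sketch (assumption: model name and generation usage are examples,
# not taken from the commit). Shows why asking the model for its device is
# safer than hard-coding cuda:0 when bitsandbytes + device_map="auto" may
# shard the model across multiple GPUs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "facebook/opt-1.3b"  # hypothetical example model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",  # accelerate may place layers on cuda:0, cuda:1, ...
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Old behavior: torch.cuda.is_available() -> torch.device('cuda'), i.e. always cuda:0.
# New behavior: defer to the model, which reports the device of its first parameters.
device = getattr(model, 'device', None)

inputs = tokenizer("Hello", return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))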