Add fallbacks after 93aa7b3ed3

This commit is contained in:
oobabooga 2025-10-09 10:59:34 -07:00
parent 1aa2b924d2
commit 218dc01b51

View file

@ -2,12 +2,27 @@ import gc
import torch
from accelerate.utils import is_npu_available, is_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import shared
def get_device():
return getattr(shared.model, 'device', None)
if hasattr(shared.model, 'device'):
return shared.model.device
elif torch.cuda.is_available():
return torch.device('cuda')
elif shared.args.deepspeed:
import deepspeed
return deepspeed.get_accelerator().current_device_name()
elif torch.backends.mps.is_available():
return torch.device('mps')
elif is_torch_xpu_available():
return torch.device('xpu:0')
elif is_torch_npu_available():
return torch.device('npu:0')
else:
return None
def clear_torch_cache():