mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-20 22:13:43 +00:00
Refactor the transformers loader (#6859)
This commit is contained in:
parent
6ba0164c70
commit
ae02ffc605
18 changed files with 464 additions and 528 deletions
37
modules/torch_utils.py
Normal file
37
modules/torch_utils.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import gc
|
||||
|
||||
import torch
|
||||
from accelerate.utils import is_npu_available, is_xpu_available
|
||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device('cuda')
|
||||
elif shared.args.deepspeed:
|
||||
import deepspeed
|
||||
return deepspeed.get_accelerator().current_device_name()
|
||||
elif torch.backends.mps.is_available():
|
||||
return torch.device('mps')
|
||||
elif is_torch_xpu_available():
|
||||
return torch.device('xpu:0')
|
||||
elif is_torch_npu_available():
|
||||
return torch.device('npu:0')
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
elif torch.backends.mps.is_available():
|
||||
if hasattr(torch.backends.mps, 'empty_cache'):
|
||||
torch.backends.mps.empty_cache()
|
||||
Loading…
Add table
Add a link
Reference in a new issue