Add Ascend NPU support (basic) (#5541)

This commit is contained in:
wangshuai09 2024-04-12 05:42:20 +08:00 committed by GitHub
parent a90509d82e
commit fd4e46bce2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 7 deletions

View file

@@ -10,7 +10,11 @@ import traceback
import numpy as np
import torch
import transformers
from transformers import LogitsProcessorList, is_torch_xpu_available
from transformers import (
LogitsProcessorList,
is_torch_npu_available,
is_torch_xpu_available
)
import modules.shared as shared
from modules.cache_utils import process_llamacpp_cache
@@ -24,7 +28,7 @@ from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_basic_html
from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank
from modules.models import clear_torch_cache
def generate_reply(*args, **kwargs):
@@ -131,12 +135,15 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
return input_ids
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
import deepspeed
return input_ids.to(deepspeed.get_accelerator().current_device_name())
elif torch.backends.mps.is_available():
device = torch.device('mps')
return input_ids.to(device)
elif is_torch_xpu_available():
return input_ids.to("xpu:0")
elif is_torch_npu_available():
return input_ids.to("npu:0")
else:
return input_ids.cuda()
@@ -213,6 +220,8 @@ def set_manual_seed(seed):
torch.cuda.manual_seed_all(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed_all(seed)
elif is_torch_npu_available():
torch.npu.manual_seed_all(seed)
return seed