Add Ascend NPU support (basic) (#5541)

This commit is contained in:
wangshuai09 2024-04-12 05:42:20 +08:00 committed by GitHub
parent a90509d82e
commit fd4e46bce2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 35 additions and 7 deletions

View file

@ -1,5 +1,5 @@
import torch
from transformers import is_torch_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import sampler_hijack, shared
from modules.logging_colors import logger
@ -34,6 +34,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
if is_non_hf_exllamav2:
if is_torch_xpu_available():
tokens = shared.tokenizer.encode(prompt).to("xpu:0")
elif is_torch_npu_available():
tokens = shared.tokenizer.encode(prompt).to("npu:0")
else:
tokens = shared.tokenizer.encode(prompt).cuda()
scores = shared.model.get_logits(tokens)[-1][-1]
@ -43,6 +45,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
else:
if is_torch_xpu_available():
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
elif is_torch_npu_available():
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
else:
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
output = shared.model(input_ids=tokens)