mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-04 14:17:28 +00:00
Add Ascend NPU support (basic) (#5541)
This commit is contained in:
parent
a90509d82e
commit
fd4e46bce2
5 changed files with 35 additions and 7 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from transformers import is_torch_xpu_available
|
||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
from modules import sampler_hijack, shared
|
||||
from modules.logging_colors import logger
|
||||
|
|
@ -34,6 +34,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
|
|||
if is_non_hf_exllamav2:
|
||||
if is_torch_xpu_available():
|
||||
tokens = shared.tokenizer.encode(prompt).to("xpu:0")
|
||||
elif is_torch_npu_available():
|
||||
tokens = shared.tokenizer.encode(prompt).to("npu:0")
|
||||
else:
|
||||
tokens = shared.tokenizer.encode(prompt).cuda()
|
||||
scores = shared.model.get_logits(tokens)[-1][-1]
|
||||
|
|
@ -43,6 +45,8 @@ def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return
|
|||
else:
|
||||
if is_torch_xpu_available():
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
|
||||
elif is_torch_npu_available():
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
|
||||
else:
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||
output = shared.model(input_ids=tokens)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue