mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-04 14:17:28 +00:00
Add Ascend NPU support (basic) (#5541)
This commit is contained in:
parent
a90509d82e
commit
fd4e46bce2
5 changed files with 35 additions and 7 deletions
|
|
@@ -10,7 +10,11 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import LogitsProcessorList, is_torch_xpu_available
|
||||
from transformers import (
|
||||
LogitsProcessorList,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available
|
||||
)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.cache_utils import process_llamacpp_cache
|
||||
|
|
@@ -24,7 +28,7 @@ from modules.grammar.grammar_utils import initialize_grammar
|
|||
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
||||
from modules.html_generator import generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import clear_torch_cache, local_rank
|
||||
from modules.models import clear_torch_cache
|
||||
|
||||
|
||||
def generate_reply(*args, **kwargs):
|
||||
|
|
@@ -131,12 +135,15 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model'] or shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.deepspeed:
|
||||
return input_ids.to(device=local_rank)
|
||||
import deepspeed
|
||||
return input_ids.to(deepspeed.get_accelerator().current_device_name())
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
return input_ids.to(device)
|
||||
elif is_torch_xpu_available():
|
||||
return input_ids.to("xpu:0")
|
||||
elif is_torch_npu_available():
|
||||
return input_ids.to("npu:0")
|
||||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
|
|
@@ -213,6 +220,8 @@ def set_manual_seed(seed):
|
|||
torch.cuda.manual_seed_all(seed)
|
||||
elif is_torch_xpu_available():
|
||||
torch.xpu.manual_seed_all(seed)
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.manual_seed_all(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue