Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2026-04-06 07:03:37 +00:00)

Fix CUDA error on MPS backend during API request (#6572)

Co-authored-by: oobabooga <oobabooga4@gmail.com>

commit 13c033c745
parent 979e1f1bd6

5 changed files with 63 additions and 65 deletions
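The hunks below all come from a single file (evidently modules/models.py, given the huggingface_loader, get_max_memory_dict, and clear_torch_cache functions; the other four changed files are not shown here). Within it the commit does three things: imports transformers' is_torch_npu_available and is_torch_xpu_available, drops the module-level last_generation_time timestamp, and centralizes device selection in a new get_device() helper. huggingface_loader() now moves the model via that helper instead of repeating per-backend branches, and clear_torch_cache() probes CUDA, XPU, NPU, and MPS explicitly rather than falling through to torch.cuda.empty_cache() whenever XPU is absent — presumably the path that raised the CUDA error on MPS-only machines named in the title.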
@@ -21,11 +21,12 @@ from transformers import (
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     BitsAndBytesConfig,
-    GPTQConfig
+    GPTQConfig,
+    is_torch_npu_available,
+    is_torch_xpu_available
 )
 
 import modules.shared as shared
 from modules import sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import get_model_metadata
@@ -56,8 +57,6 @@ if shared.args.deepspeed:
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
 sampler_hijack.hijack_samplers()
 
 
-last_generation_time = time.time()
-
@@ -172,17 +171,9 @@ def huggingface_loader(model_name):
 
         model = LoaderClass.from_pretrained(path_to_model, **params)
         if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
-            if torch.backends.mps.is_available():
-                device = torch.device('mps')
+            device = get_device()
+            if device:
                 model = model.to(device)
-            elif is_xpu_available():
-                device = torch.device("xpu")
-                model = model.to(device)
-            elif is_npu_available():
-                device = torch.device("npu")
-                model = model.to(device)
-            else:
-                model = model.cuda()
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
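For context, a minimal runnable sketch of the pattern this hunk switches to; the toy model and the trimmed get_device() below are illustrative stand-ins, not part of the commit. The loader asks one helper which device to use and moves the model only when a device is found, instead of repeating a model.to(...) branch per backend:

import torch

def get_device():
    # Trimmed version of the helper added by this commit: the DeepSpeed,
    # XPU, and NPU branches are omitted so the sketch runs anywhere.
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return None  # CPU-only build: leave the model where it is

model = torch.nn.Linear(4, 4)  # toy stand-in for LoaderClass.from_pretrained(...)
device = get_device()
if device:
    model = model.to(device)
print(next(model.parameters()).device)  # cuda:0, mps:0, or cpu

Returning None for the CPU case is what lets the caller skip the move entirely, which is exactly the if device: guard in the hunk above.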
@@ -380,13 +371,34 @@ def get_max_memory_dict():
     return max_memory if len(max_memory) > 0 else None
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device('cuda')
+    elif shared.args.deepspeed:
+        import deepspeed
+        return deepspeed.get_accelerator().current_device_name()
+    elif torch.backends.mps.is_available():
+        return torch.device('mps')
+    elif is_torch_xpu_available():
+        return torch.device('xpu:0')
+    elif is_torch_npu_available():
+        return torch.device('npu:0')
+    else:
+        return None
+
+
 def clear_torch_cache():
     gc.collect()
     if not shared.args.cpu:
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-        else:
+        if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
+        elif is_npu_available():
+            torch.npu.empty_cache()
+        elif torch.backends.mps.is_available():
+            if hasattr(torch.backends.mps, 'empty_cache'):
+                torch.backends.mps.empty_cache()
 
 
 def unload_model(keep_model_name=False):
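And a hedged sketch of why the clear_torch_cache() ordering matters: the old code assumed CUDA whenever XPU was absent, so an API request on an Apple-silicon machine hit the CUDA path; the rewrite probes each backend and simply does nothing on builds with no cache to clear. The version below keeps only the branches that run on any standard PyTorch build, and drops the shared.args.cpu check from the original:

import gc

import torch

def clear_torch_cache():
    # Same shape as the rewritten function, minus the XPU/NPU branches
    # (those require the corresponding torch builds).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # The hasattr guard carried over from the commit covers torch
        # builds where this helper does not exist.
        if hasattr(torch.backends.mps, 'empty_cache'):
            torch.backends.mps.empty_cache()
    # No trailing else: on a CPU-only build there is nothing to clear,
    # whereas the old code reached torch.cuda.empty_cache() whenever XPU
    # was absent.

clear_torch_cache()  # safe on CUDA, MPS, and CPU-only machines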