mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-04-05 06:35:15 +00:00
Fix CUDA error on MPS backend during API request (#6572)
--------- Co-authored-by: oobabooga <oobabooga4@gmail.com>
This commit is contained in:
parent
979e1f1bd6
commit
13c033c745
5 changed files with 63 additions and 65 deletions
|
|
@@ -2,11 +2,10 @@ import time
|
|||
import traceback
|
||||
|
||||
import torch
|
||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
from modules import models, sampler_hijack, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import load_model
|
||||
from modules.models import get_device, load_model
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
global_scores = None
|
||||
|
|
@@ -57,23 +56,21 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
|
|||
scores = sampler_hijack.global_scores[-1]
|
||||
else:
|
||||
if is_non_hf_exllamav2:
|
||||
if is_torch_xpu_available():
|
||||
tokens = shared.tokenizer.encode(prompt).to("xpu:0")
|
||||
elif is_torch_npu_available():
|
||||
tokens = shared.tokenizer.encode(prompt).to("npu:0")
|
||||
else:
|
||||
tokens = shared.tokenizer.encode(prompt).cuda()
|
||||
device = get_device()
|
||||
tokens = shared.tokenizer.encode(prompt)
|
||||
if device:
|
||||
tokens = tokens.to(device)
|
||||
|
||||
scores = shared.model.get_logits(tokens)[-1][-1]
|
||||
elif is_non_hf_llamacpp:
|
||||
tokens = shared.tokenizer.encode(prompt)
|
||||
scores = shared.model.get_logits(tokens)[-1][-1]
|
||||
else:
|
||||
if is_torch_xpu_available():
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("xpu:0")
|
||||
elif is_torch_npu_available():
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').to("npu:0")
|
||||
else:
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda()
|
||||
device = get_device()
|
||||
tokens = shared.tokenizer.encode(prompt, return_tensors='pt')
|
||||
if device:
|
||||
tokens = tokens.to(device)
|
||||
|
||||
output = shared.model(input_ids=tokens)
|
||||
scores = output['logits'][-1][-1]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue