Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-04-04 14:17:28 +00:00
Fix CUDA error on MPS backend during API request (#6572)
Co-authored-by: oobabooga <oobabooga4@gmail.com>

parent 979e1f1bd6
commit 13c033c745
5 changed files with 63 additions and 65 deletions
modules/text_generation.py

@@ -16,7 +16,7 @@ from transformers import (
 )
 
 import modules.shared as shared
-from modules import models
+from modules import models, sampler_hijack
 from modules.cache_utils import process_llamacpp_cache
 from modules.callbacks import (
     Iteratorize,
@@ -28,7 +28,9 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache, load_model
+from modules.models import clear_torch_cache, get_device, load_model
 
+sampler_hijack.hijack_samplers()
+
 
 def generate_reply(*args, **kwargs):
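The get_device name added to this import comes from modules/models.py, one of the five changed files that is not shown on this page. Below is a minimal sketch of what such a helper could look like, reconstructed only from how the call sites in the later hunks use it: it should return a torch.device for the active accelerator and a falsy value (None) when tensors should stay on CPU. This is an assumption made for illustration, not the commit's actual implementation.

# Hypothetical reconstruction of the helper defined in modules/models.py.
# The call sites in this diff rely on just two behaviors: a torch.device is
# returned when an accelerator should be used, and None is returned for
# plain CPU setups (so callers can write "if device:").
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    elif is_torch_xpu_available():
        return torch.device('xpu:0')
    elif is_torch_npu_available():
        return torch.device('npu:0')

    return None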
@@ -159,18 +161,12 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
 
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
         return input_ids
-    elif shared.args.deepspeed:
-        import deepspeed
-        return input_ids.to(deepspeed.get_accelerator().current_device_name())
-    elif torch.backends.mps.is_available():
-        device = torch.device('mps')
-        return input_ids.to(device)
-    elif is_torch_xpu_available():
-        return input_ids.to("xpu:0")
-    elif is_torch_npu_available():
-        return input_ids.to("npu:0")
     else:
-        return input_ids.cuda()
+        device = get_device()
+        if device:
+            return input_ids.to(device)
+
+        return input_ids
 
 
 def decode(output_ids, skip_special_tokens=True):
@@ -328,7 +324,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
-    cuda = not any((shared.args.cpu, shared.args.deepspeed))
     if state['auto_max_new_tokens']:
         generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
 
@@ -383,8 +378,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         if not state['stream']:
             with torch.no_grad():
                 output = shared.model.generate(**generate_params)[0]
-                if cuda:
-                    output = output.cuda()
+                device = get_device()
+                if device:
+                    output = output.to(device)
 
             starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             yield get_reply_from_output_ids(output, state, starting_from=starting_from)
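For context on the commit title: the cuda flag removed above only excluded the shared.args.cpu and shared.args.deepspeed cases, so on an Apple Silicon machine serving an API request the old code still reached output.cuda() and raised a CUDA error. A small, self-contained illustration of that failure and of the backend-agnostic replacement follows; it assumes a CPU- or MPS-only PyTorch build, and the exact error wording varies by build.

# Minimal reproduction of the failure this commit fixes, runnable on any
# machine without a CUDA device (for example an Apple Silicon Mac).
import torch

output = torch.ones(4)

try:
    # Old behavior: move to CUDA whenever cpu/deepspeed are not in use.
    output = output.cuda()
except (AssertionError, RuntimeError) as err:
    print(f"CUDA move failed: {err}")

# New behavior: move to whichever backend is actually available, if any.
device = torch.device('mps') if torch.backends.mps.is_available() else None
if device:
    output = output.to(device)

print(output.device)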