Fix CUDA error on MPS backend during API request (#6572)

---------

Co-authored-by: oobabooga <oobabooga4@gmail.com>
Petr Korolev 2025-01-02 06:06:11 +03:00 committed by GitHub
parent 979e1f1bd6
commit 13c033c745
5 changed files with 63 additions and 65 deletions

modules/sampler_hijack.py

@@ -5,7 +5,7 @@ import random
 import torch
 import transformers
-from transformers import LogitsWarper, is_torch_xpu_available
+from transformers import LogitsWarper
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -14,6 +14,7 @@ from transformers.generation.logits_process import (
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import get_device

 global_scores = None
@@ -339,12 +340,12 @@ class MirostatLogitsWarper(LogitsWarper):
                 break

         # Normalize the probabilities of the remaining words
-        if is_torch_xpu_available():
-            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
-        else:
-            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        prob_topk = torch.softmax(sorted_logits, dim=0)
+        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+        device = get_device()
+        if device:
+            prob_topk = prob_topk.to(device)
+            prev_i = prev_i.to(device)

         observed_surprise = -math.log2(prob_topk[prev_i])
         self.e = observed_surprise - self.mirostat_tau
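
The change routes every tensor move through a single get_device() helper from modules.models instead of hardcoding 'cuda' or 'xpu', which is what raised the CUDA error on Apple's MPS backend. The helper itself is not part of this diff; the following is a minimal sketch of what such a device-resolution function could look like, assuming it returns None when only the CPU is available (which is why the call site guards with "if device:"):

import torch

def get_device():
    # Hypothetical sketch of the modules/models.py helper this diff imports;
    # the real version may also honor launch flags such as a --cpu option.
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        # Apple Silicon path: returning 'mps' here avoids the
        # hardcoded-'cuda' move that failed during API requests.
        return torch.device('mps')
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        return torch.device('xpu')  # Intel GPUs via the XPU backend
    return None  # falsy: callers leave tensors on the CPU

Returning None rather than torch.device('cpu') keeps the call site cheap: the .to(device) move only happens when an accelerator actually exists.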