Mirror of https://github.com/oobabooga/text-generation-webui.git
Fix CUDA error on MPS backend during API request (#6572)
Co-authored-by: oobabooga <oobabooga4@gmail.com>
Parent: 979e1f1bd6
Commit: 13c033c745
5 changed files with 63 additions and 65 deletions
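In short: the Mirostat sampler previously moved tensors to a hard-coded "cuda" (or "xpu") device; this commit moves them to whatever device modules.models.get_device() reports, or leaves them in place when it reports none, so API requests on backends such as Apple's MPS no longer fail with a CUDA error.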
@@ -5,7 +5,7 @@ import random
 
 import torch
 import transformers
-from transformers import LogitsWarper, is_torch_xpu_available
+from transformers import LogitsWarper
 from transformers.generation.logits_process import (
     LogitNormalization,
     LogitsProcessor,
@@ -14,6 +14,7 @@ from transformers.generation.logits_process import (
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import get_device
 
 global_scores = None
 
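The new import centralizes device selection. The actual get_device() implementation lives in modules/models.py and is not shown in this diff; a minimal sketch of such a helper, assuming it simply probes CUDA, Apple MPS, and Intel XPU availability and returns None on CPU-only machines, could look like this:

import torch
from transformers import is_torch_xpu_available


def get_device():
    # Hypothetical sketch, not the repository's actual implementation:
    # return the accelerator that tensors should be moved to, or None for CPU-only setups.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if is_torch_xpu_available():
        return torch.device("xpu")
    return None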
@@ -339,12 +340,12 @@ class MirostatLogitsWarper(LogitsWarper):
                 break
 
         # Normalize the probabilities of the remaining words
-        if is_torch_xpu_available():
-            prob_topk = torch.softmax(sorted_logits, dim=0).to("xpu")
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to("xpu")
-        else:
-            prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
-            prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
+        prob_topk = torch.softmax(sorted_logits, dim=0)
+        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)
+        device = get_device()
+        if device:
+            prob_topk = prob_topk.to(device)
+            prev_i = prev_i.to(device)
 
         observed_surprise = -math.log2(prob_topk[prev_i])
         self.e = observed_surprise - self.mirostat_tau
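The design choice here: the softmax and the multinomial draw are computed on whatever device sorted_logits already lives on, and the results are moved only afterwards, and only when get_device() reports an accelerator. That removes the unconditional .to('cuda') call that raised a CUDA error on machines where only MPS (or just the CPU) is available.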