import deepspeed only if use_deepspeed is True

This commit is contained in:
manmay-nakhashi 2023-07-10 07:47:46 +05:30
parent 31a2e153ff
commit 82724cca54

View file

@ -3,7 +3,6 @@ import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepspeed
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
@ -362,6 +361,7 @@ class UnifiedVoice(nn.Module):
self.mel_head
)
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(model=self.inference_model,
mp_size=1,
replace_with_kernel_inject=True,