From aaed65869acfb8229ad938d38855b5aa8d13cdb6 Mon Sep 17 00:00:00 2001 From: manmay-nakhashi Date: Sun, 16 Jul 2023 16:00:40 +0530 Subject: [PATCH] Fix clvp callable being rebound to its own output --- tortoise/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index 6b6b32d..68e1482 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -445,7 +445,7 @@ class TextToSpeech: for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) if cvvp_amount != 1: - clvp = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) if auto_conds is not None and cvvp_amount > 0: cvvp_accumulator = 0 for cl in range(auto_conds.shape[1]): @@ -454,9 +454,9 @@ class TextToSpeech: if cvvp_amount == 1: clip_results.append(cvvp) else: - clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) + clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount)) else: - clip_results.append(clvp) + clip_results.append(clvp_out) clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) best_results = samples[torch.topk(clip_results, k=k).indices]