From d783b24ba7d8a03dabfa2139fcbf40c216e0ea3d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 16 Mar 2023 19:34:52 +0800
Subject: [PATCH] Merge pull request #199 from alibaba-damo-academy/dev_xw

---
 funasr/models/predictor/cif.py |   16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 60cf902..5615373 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -68,7 +68,8 @@
             mask_2 = torch.cat([ones_t, mask], dim=1)
             mask = mask_2 - mask_1
             tail_threshold = mask * tail_threshold
-            alphas = torch.cat([alphas, tail_threshold], dim=1)
+            alphas = torch.cat([alphas, zeros_t], dim=1)
+            alphas = torch.add(alphas, tail_threshold)
         else:
             tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
             tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -208,7 +209,8 @@
             mask_2 = torch.cat([ones_t, mask], dim=1)
             mask = mask_2 - mask_1
             tail_threshold = mask * tail_threshold
-            alphas = torch.cat([alphas, tail_threshold], dim=1)
+            alphas = torch.cat([alphas, zeros_t], dim=1)
+            alphas = torch.add(alphas, tail_threshold)
         else:
             tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
             tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -542,9 +544,8 @@
             token_num_int = torch.max(token_num).type(torch.int32).item()
             acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
         return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-    
-    def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
-                target_label_length=None, token_num=None):
+
+    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
         h = hidden
         b = hidden.shape[0]
         context = h.transpose(1, 2)
@@ -596,7 +597,8 @@
             mask_2 = torch.cat([ones_t, mask], dim=1)
             mask = mask_2 - mask_1
             tail_threshold = mask * tail_threshold
-            alphas = torch.cat([alphas, tail_threshold], dim=1)
+            alphas = torch.cat([alphas, zeros_t], dim=1)
+            alphas = torch.add(alphas, tail_threshold)
         else:
             tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
             tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -654,4 +656,4 @@
 
         predictor_alignments = index_div_bool_zeros_count_tile_out
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
-        return predictor_alignments.detach(), predictor_alignments_length.detach()
\ No newline at end of file
+        return predictor_alignments.detach(), predictor_alignments_length.detach()

--
Gitblit v1.9.1