From 2e336050db6b2c634f9ae5f7c4ef6710fd641822 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 07 二月 2023 20:40:53 +0800
Subject: [PATCH] Merge pull request #72 from alibaba-damo-academy/dev_lzr

---
 funasr/models/predictor/cif.py |    5 +++--
 1 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 60cf902..00c5a3e 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -208,7 +208,8 @@
             mask_2 = torch.cat([ones_t, mask], dim=1)
             mask = mask_2 - mask_1
             tail_threshold = mask * tail_threshold
-            alphas = torch.cat([alphas, tail_threshold], dim=1)
+            alphas = torch.cat([alphas, zeros_t], dim=1)
+            alphas = torch.add(alphas, tail_threshold)
         else:
             tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
             tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -654,4 +655,4 @@
 
         predictor_alignments = index_div_bool_zeros_count_tile_out
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
-        return predictor_alignments.detach(), predictor_alignments_length.detach()
\ No newline at end of file
+        return predictor_alignments.detach(), predictor_alignments_length.detach()

--
Gitblit v1.9.1