From a9e857e45250b16af60d5fe3efcd06e685f6506a Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: 星期六, 03 十二月 2022 16:39:38 +0800
Subject: [PATCH] update funasr 0.1.3

---
 funasr/models/predictor/cif.py |    4 ++--
 1 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 8199708..ea41c6c 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -4,7 +4,7 @@
 from funasr.modules.nets_utils import make_pad_mask
 
 class CifPredictor(nn.Module):
-    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
+    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
         super(CifPredictor, self).__init__()
 
         self.pad = nn.ConstantPad1d((l_order, r_order), 0)
@@ -147,7 +147,7 @@
         b, t, d = hidden.size()
         tail_threshold = self.tail_threshold
         tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
-        tail_threshold = torch.reshape(tail_threshold, (1, 1))
+        tail_threshold = tail_threshold.unsqueeze(0).repeat(b, 1)
         alphas = torch.cat([alphas, tail_threshold], dim=1)
         zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
         hidden = torch.cat([hidden, zeros], dim=1)

--
Gitblit v1.9.1