From a9e857e45250b16af60d5fe3efcd06e685f6506a Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: 星期六, 03 十二月 2022 16:39:38 +0800
Subject: [PATCH] update funasr 0.1.3
---
funasr/models/predictor/cif.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 8199708..ea41c6c 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -4,7 +4,7 @@
from funasr.modules.nets_utils import make_pad_mask
class CifPredictor(nn.Module):
- def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
+ def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
super(CifPredictor, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
@@ -147,7 +147,7 @@
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
- tail_threshold = torch.reshape(tail_threshold, (1, 1))
+ tail_threshold = tail_threshold.unsqueeze(0).repeat(b, 1)
alphas = torch.cat([alphas, tail_threshold], dim=1)
zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
hidden = torch.cat([hidden, zeros], dim=1)
--
Gitblit v1.9.1