From 7b2654f3576e1f1cb8381c8fc169b8dfccd8449d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 09 二月 2023 17:25:21 +0800
Subject: [PATCH] readme
---
funasr/models/predictor/cif.py | 6 ++++--
1 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 00c5a3e..c34759d 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -68,7 +68,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -597,7 +598,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
--
Gitblit v1.9.1