From 5e7a8d1ccae80e54f2e2ecfffdf8e4294800b5c3 Mon Sep 17 00:00:00 2001
From: Legend <me@liux.pro>
Date: 星期日, 15 十二月 2024 01:47:12 +0800
Subject: [PATCH] Update readme_zh.md (#2312)
---
funasr/models/paraformer/cif_predictor.py | 10 +++++-----
1 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 535131f..d597050 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -245,7 +245,7 @@
hidden, alphas, token_num, mask=None
)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -522,7 +522,7 @@
fires = fires + prefix_sum - prefix_sum_floor
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
- prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+ prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -534,7 +534,7 @@
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
- remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+ remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
@@ -702,7 +702,7 @@
# frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
# prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
- prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+ prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -715,7 +715,7 @@
remains = fires - torch.floor(fires)
# remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
- remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+ remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
--
Gitblit v1.9.1