From e2b3edec45fe1bfd76493acc366df971e89e7ae2 Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: Thu, 10 Aug 2023 17:30:34 +0800
Subject: [PATCH] Merge pull request #830 from alibaba-damo-academy/dev_ts
---
funasr/utils/timestamp_tools.py | 38 +++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 5787f1d..c194179 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -7,6 +7,24 @@
from itertools import zip_longest
+def cif_wo_hidden(alphas, threshold):
+ batch_size, len_time = alphas.size()
+    # loop vars
+ integrate = torch.zeros([batch_size], device=alphas.device)
+ # intermediate vars along time
+ list_fires = []
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ integrate += alpha
+ list_fires.append(integrate)
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], device=alphas.device),
+ integrate)
+ fires = torch.stack(list_fires, 1)
+ return fires
+
+
def ts_prediction_lfr6_standard(us_alphas,
us_peaks,
char_list,
@@ -20,25 +38,23 @@
MAX_TOKEN_DURATION = 12
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
if len(us_alphas.shape) == 2:
- _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
+ alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
- _, peaks = us_alphas, us_peaks
- num_frames = peaks.shape[0]
+ alphas, peaks = us_alphas, us_peaks
if char_list[-1] == '</s>':
char_list = char_list[:-1]
+ fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
+ if len(fire_place) != len(char_list) + 1:
+ alphas /= (alphas.sum() / (len(char_list) + 1))
+ alphas = alphas.unsqueeze(0)
+ peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
+ fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
+ num_frames = peaks.shape[0]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
- num_peak = len(fire_place)
- if num_peak != len(char_list) + 1:
- logging.warning("length mismatch, result might be incorrect.")
- logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
- if num_peak > len(char_list) + 1:
- fire_place = fire_place[:len(char_list) - 1]
- elif num_peak < len(char_list) + 1:
- char_list = char_list[:num_peak + 1]
# assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
--
Gitblit v1.9.1