From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/utils/timestamp_tools.py |   41 ++++++++++++++++++++++++++++++-----------
 1 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 4e7a8a9..6594273 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -1,14 +1,28 @@
-from itertools import zip_longest
-
 import torch
-import copy
 import codecs
 import logging
-import edit_distance
 import argparse
-import pdb
 import numpy as np
-from typing import Any, List, Tuple, Union
+import edit_distance
+from itertools import zip_longest
+
+
+def cif_wo_hidden(alphas, threshold):
+    batch_size, len_time = alphas.size()
+    # loop varss
+    integrate = torch.zeros([batch_size], device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+    for t in range(len_time):
+        alpha = alphas[:, t]
+        integrate += alpha
+        list_fires.append(integrate)
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
+                                integrate)
+    fires = torch.stack(list_fires, 1)
+    return fires
 
 
 def ts_prediction_lfr6_standard(us_alphas, 
@@ -24,19 +38,24 @@
     MAX_TOKEN_DURATION = 12
     TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
     if len(us_alphas.shape) == 2:
-        _, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
+        alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
-        _, peaks = us_alphas, us_peaks
-    num_frames = peaks.shape[0]
+        alphas, peaks = us_alphas, us_peaks
     if char_list[-1] == '</s>':
         char_list = char_list[:-1]
+    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
+    if len(fire_place) != len(char_list) + 1:
+        alphas /= (alphas.sum() / (len(char_list) + 1))
+        alphas = alphas.unsqueeze(0)
+        peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
+        fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
+    num_frames = peaks.shape[0]
     timestamp_list = []
     new_char_list = []
     # for bicif model trained with large data, cif2 actually fires when a character starts
     # so treat the frames between two peaks as the duration of the former token
     fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
-    num_peak = len(fire_place)
-    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
     # begin silence
     if fire_place[0] > START_END_THRESHOLD:
         # char_list.insert(0, '<sil>')

--
Gitblit v1.9.1