From 20aa07268a7fafaaab7762b488615af32a0e82b4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 11 六月 2024 14:02:27 +0800
Subject: [PATCH] update with main (#1800)

---
 funasr/models/paraformer/model.py             |   24 ++++++-
 funasr/utils/timestamp_tools.py               |    6 +-
 funasr/auto/auto_model.py                     |    4 +
 runtime/docs/SDK_advanced_guide_offline_zh.md |    1 
 funasr/models/paraformer/cif_predictor.py     |  127 +++++++++++++++++++++++++++++++++++++++++-
 funasr/models/llm_asr/adaptor.py              |    1 
 docs/images/wechat.png                        |    0 
 7 files changed, 153 insertions(+), 10 deletions(-)

diff --git a/docs/images/wechat.png b/docs/images/wechat.png
index b5d9a55..491a1c3 100644
--- a/docs/images/wechat.png
+++ b/docs/images/wechat.png
Binary files differ
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 22b1ac0..047e652 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -429,6 +429,10 @@
             #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
             #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
 
+            if len(results_sorted) != n:
+                results_ret_list.append({"key": key, "text": "", "timestamp": []})
+                logging.info("decoding, utt: {}, empty result".format(key))
+                continue
             restored_data = [0] * n
             for j in range(n):
                 index = sorted_data[j][1]
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index c939883..93534fe 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -125,6 +125,7 @@
         olens = None
         olens = (ilens - 1) // self.k + 1
         masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+
         if self.blocks is not None:
             for layer, block in enumerate(self.blocks):
                 x, masks = block(x, masks)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 8b1a9bb..7490310 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
                     hidden, alphas, token_num, mask=mask
                 )
 
-            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
 
             if target_length is None and self.tail_threshold > 0.0:
                 token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@
                         hidden, alphas, token_num, mask=None
                     )
 
-            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
             if target_length is None and self.tail_threshold > 0.0:
                 token_num_int = torch.max(token_num).type(torch.int32).item()
                 acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@
         mask = mask.transpose(-1, -2).float()
         mask = mask.squeeze(-1)
         hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
@@ -494,7 +494,60 @@
         token_num_floor = torch.floor(token_num)
 
         return hidden, alphas, token_num_floor
+@torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
 
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
 
 @torch.jit.script
 def cif_export(hidden, alphas, threshold: float):
@@ -608,6 +661,74 @@
     return torch.stack(list_ls, 0), fires
 
 
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+    batch_size, len_time = alphas.size()
+    device = alphas.device
+    dtype = alphas.dtype
+
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+    if return_fire_idxs:
+        return fires, fire_idxs
+    return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    frames = torch.zeros(batch_size, len_time, hidden_size,
+                         dtype=dtype, device=device)
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1,
+                                               hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
+
 def cif_wo_hidden(alphas, threshold):
     batch_size, len_time = alphas.size()
 
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2b..85967af 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
@@ -452,6 +454,7 @@
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
             predictor_outs[2],
             predictor_outs[3],
         )
+        
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
@@ -564,10 +568,22 @@
                     # Change integer-ids to tokens
                     token = tokenizer.ids2tokens(token_int)
                     text_postprocessed = tokenizer.tokens2text(token)
-                    if not hasattr(tokenizer, "bpemodel"):
-                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-                    result_i = {"key": key[i], "text": text_postprocessed}
+                    
+                    if pred_timestamp:
+                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                            pre_peak_index[i],
+                            alphas[i],
+                            copy.copy(token),
+                            vad_offset=kwargs.get("begin_time", 0),
+                            upsample_rate=1,
+                        )
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+                    else:
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                        result_i = {"key": key[i], "text": text_postprocessed}
 
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d773..af61e5a 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
+    MAX_TOKEN_DURATION = 12  #  3 times upsampled
+    TIME_RATE=10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
diff --git a/runtime/docs/SDK_advanced_guide_offline_zh.md b/runtime/docs/SDK_advanced_guide_offline_zh.md
index 1cecb88..902e169 100644
--- a/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -149,6 +149,7 @@
 --port 10095 閮ㄧ讲绔彛鍙�
 --wav-path 闇�瑕佽繘琛岃浆鍐欑殑闊抽鏂囦欢锛屾敮鎸佹枃浠惰矾寰�
 --hotword 鐑瘝鏂囦欢锛屾瘡琛屼竴涓儹璇嶏紝鏍煎紡(鐑瘝 鏉冮噸)锛氶樋閲屽反宸� 20
+--thread-num 璁剧疆瀹㈡埛绔嚎绋嬫暟
 --use-itn 璁剧疆鏄惁浣跨敤itn锛岄粯璁�1寮�鍚紝璁剧疆涓�0鍏抽棴
 ```
 

--
Gitblit v1.9.1