From 20aa07268a7fafaaab7762b488615af32a0e82b4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 11 Jun 2024 14:02:27 +0800
Subject: [PATCH] update with main (#1800)
---
funasr/models/paraformer/model.py | 24 ++++++-
funasr/utils/timestamp_tools.py | 6 +-
funasr/auto/auto_model.py | 4 +
runtime/docs/SDK_advanced_guide_offline_zh.md | 1
funasr/models/paraformer/cif_predictor.py | 127 +++++++++++++++++++++++++++++++++++++++++-
funasr/models/llm_asr/adaptor.py | 1
docs/images/wechat.png | 0
7 files changed, 153 insertions(+), 10 deletions(-)
diff --git a/docs/images/wechat.png b/docs/images/wechat.png
index b5d9a55..491a1c3 100644
--- a/docs/images/wechat.png
+++ b/docs/images/wechat.png
Binary files differ
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 22b1ac0..047e652 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -429,6 +429,10 @@
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+ if len(results_sorted) != n:
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
+ logging.info("decoding, utt: {}, empty result".format(key))
+ continue
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index c939883..93534fe 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -125,6 +125,7 @@
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+
if self.blocks is not None:
for layer, block in enumerate(self.blocks):
x, masks = block(x, masks)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 8b1a9bb..7490310 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
hidden, alphas, token_num, mask=mask
)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@
hidden, alphas, token_num, mask=None
)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@
mask = mask.transpose(-1, -2).float()
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
@@ -494,7 +494,60 @@
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
+@torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+ device = hidden.device
+ dtype = hidden.dtype
+ batch_size, len_time, hidden_size = hidden.size()
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+ frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+ fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+ prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum_floor = torch.floor(prefix_sum)
+ dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+ dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+ dislocation_prefix_sum_floor[:, 0] = 0
+ dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+ fire_idxs = dislocation_diff > 0
+ fires[fire_idxs] = 1
+ fires = fires + prefix_sum - prefix_sum_floor
+
+ prefix_sum_hidden = torch.cumsum(
+ alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+ )
+
+ frames = prefix_sum_hidden[fire_idxs]
+ shift_frames = torch.roll(frames, 1, dims=0)
+
+ batch_len = fire_idxs.sum(1)
+ batch_idxs = torch.cumsum(batch_len, dim=0)
+ shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+ shift_batch_idxs[0] = 0
+ shift_frames[shift_batch_idxs] = 0
+
+ remains = fires - torch.floor(fires)
+ remain_frames = (
+ remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+ )
+
+ shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+ shift_remain_frames[shift_batch_idxs] = 0
+
+ frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+ max_label_len = batch_len.max()
+
+ frame_fires = torch.zeros(
+ batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+ )
+ indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+ frame_fires_idxs = indices < batch_len.unsqueeze(1)
+ frame_fires[frame_fires_idxs] = frames
+ return frame_fires, fires
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
@@ -608,6 +661,74 @@
return torch.stack(list_ls, 0), fires
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+ batch_size, len_time = alphas.size()
+ device = alphas.device
+ dtype = alphas.dtype
+
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+ fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+ prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum_floor = torch.floor(prefix_sum)
+ dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+ dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+ dislocation_prefix_sum_floor[:, 0] = 0
+ dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+ fire_idxs = dislocation_diff > 0
+ fires[fire_idxs] = 1
+ fires = fires + prefix_sum - prefix_sum_floor
+ if return_fire_idxs:
+ return fires, fire_idxs
+ return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+ fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+ device = hidden.device
+ dtype = hidden.dtype
+ batch_size, len_time, hidden_size = hidden.size()
+ frames = torch.zeros(batch_size, len_time, hidden_size,
+ dtype=dtype, device=device)
+ prefix_sum_hidden = torch.cumsum(
+ alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+ )
+
+ frames = prefix_sum_hidden[fire_idxs]
+ shift_frames = torch.roll(frames, 1, dims=0)
+
+ batch_len = fire_idxs.sum(1)
+ batch_idxs = torch.cumsum(batch_len, dim=0)
+ shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+ shift_batch_idxs[0] = 0
+ shift_frames[shift_batch_idxs] = 0
+
+ remains = fires - torch.floor(fires)
+ remain_frames = (
+ remains[fire_idxs].unsqueeze(-1).tile((1,
+ hidden_size)) * hidden[fire_idxs]
+ )
+
+ shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+ shift_remain_frames[shift_batch_idxs] = 0
+
+ frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+ max_label_len = batch_len.max()
+
+ frame_fires = torch.zeros(
+ batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+ )
+ indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+ frame_fires_idxs = indices < batch_len.unsqueeze(1)
+ frame_fires[frame_fires_idxs] = frames
+ return frame_fires, fires
+
+
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2b..85967af 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
# MIT License (https://opensource.org/licenses/MIT)
import time
+import copy
import torch
import logging
from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -452,6 +454,7 @@
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
+ pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
predictor_outs[2],
predictor_outs[3],
)
+
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
@@ -564,10 +568,22 @@
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text_postprocessed = tokenizer.tokens2text(token)
- if not hasattr(tokenizer, "bpemodel"):
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
- result_i = {"key": key[i], "text": text_postprocessed}
+
+ if pred_timestamp:
+ timestamp_str, timestamp = ts_prediction_lfr6_standard(
+ pre_peak_index[i],
+ alphas[i],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0),
+ upsample_rate=1,
+ )
+ if not hasattr(tokenizer, "bpemodel"):
+ text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+ result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+ else:
+ if not hasattr(tokenizer, "bpemodel"):
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "text": text_postprocessed}
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d773..af61e5a 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@
def ts_prediction_lfr6_standard(
- us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+ us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
- MAX_TOKEN_DURATION = 12
- TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
+ MAX_TOKEN_DURATION = 12 # 3 times upsampled
+ TIME_RATE=10.0 * 6 / 1000 / upsample_rate
if len(us_alphas.shape) == 2:
alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
diff --git a/runtime/docs/SDK_advanced_guide_offline_zh.md b/runtime/docs/SDK_advanced_guide_offline_zh.md
index 1cecb88..902e169 100644
--- a/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -149,6 +149,7 @@
--port 10095 部署端口号
--wav-path 需要进行转写的音频文件,支持文件路径
--hotword 热词文件,每行一个热词,格式(热词 权重):阿里巴巴 20
+--thread-num 设置客户端线程数
--use-itn 设置是否使用itn,默认1开启,设置为0关闭
```
--
Gitblit v1.9.1