From 6baf10d5d15bed3948459b30567edd3c5898ff84 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 二月 2023 10:31:05 +0800
Subject: [PATCH] Merge branch 'main' into dev_zly
---
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py | 4
funasr/bin/asr_inference_paraformer.py | 174 +++++++++++++++++++++++++++--
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md | 2
funasr/bin/asr_inference_paraformer_vad.py | 1
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py | 4
funasr/utils/timestamp_tools.py | 82 -------------
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py | 4
funasr/bin/asr_inference_paraformer_vad_punc.py | 16 +-
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py | 2
README.md | 22 ++-
funasr/export/export_model.py | 1
11 files changed, 193 insertions(+), 119 deletions(-)
diff --git a/README.md b/README.md
index 6d44e6d..eaeb9ba 100644
--- a/README.md
+++ b/README.md
@@ -17,21 +17,23 @@
## What's new:
-### 2023.2.16, funasr-0.2.0
-- We support a new feature, export paraformer models into [onnx and torchscripts](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) from modelscopes. The local finetuned models are also supported.
-- We support a new feature, [onnxruntime](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer), you could deploy the runtime without modelscope or funasr, for the [paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model, the rtf of onnxruntime is 3x speedup(0.110->0.038) on cpu.
-- We support e new feature, [grpc](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc), you could build the ASR service with grpc, by deploying the modelscope pipeline or onnxruntime.
+### 2023.2.17, funasr-0.2.0, modelscope-1.3.0
+- We support a new feature, export paraformer models into [onnx and torchscripts](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) from modelscope. The local finetuned models are also supported.
+- We support a new feature, [onnxruntime](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer), you could deploy the runtime without modelscope or funasr, for the [paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model, the rtf of onnxruntime is 3x speedup(0.110->0.038) on cpu, [details](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer#speed).
+- We support a new feature, [grpc](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc), you could build the ASR service with grpc, by deploying the modelscope pipeline or onnxruntime.
- We release a new model [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary), which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords.
+- We optimize the timestamp alignment of [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary); timestamp prediction accuracy is much improved, achieving an accumulated average shift (aas) of 74.7ms, [details](https://arxiv.org/abs/2301.12343).
- We release a new model, [8k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which could predict the duration of none-silence speech. It could be freely integrated with any ASR models in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
-- We release a new model, [MFCCA](https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary), a multi-channel multi-speaker model which is independent of the number and geometry of microphones and supports Mandarin meeting transcription.
+- We release a new model, [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary), a multi-channel multi-speaker model which is independent of the number and geometry of microphones and supports Mandarin meeting transcription.
- We release several new UniASR model:
-[Southern Fujian Dialect model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary),
+[Southern Fujian Dialect model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary),
[French model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary),
[German model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary),
[Vietnamese model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary),
[Persian model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary).
- We release a new model, [paraformer-data2vec model](https://www.modelscope.cn/models/damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/summary), an unsupervised pretraining model on AISHELL-2, which is inited for paraformer model and then finetune on AISHEL-1.
-### 2023.1.16, funasr-0.1.6
+- Various new audio input types are now supported by the modelscope inference pipeline, including mp3, flac, ogg, opus, ...
+### 2023.1.16, funasr-0.1.6, modelscope-1.2.0
- We release a new version model [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary), which integrate the [VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) model, [ASR](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary),
[Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) model and timestamp together. The model could take in several hours long inputs.
- We release a new model, [16k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which could predict the duration of none-silence speech. It could be freely integrated with any ASR models in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
@@ -101,4 +103,10 @@
booktitle={INTERSPEECH},
year={2022}
}
+@inproceedings{Shi2023AchievingTP,
+ title={Achieving Timestamp Prediction While Recognizing with Non-Autoregressive End-to-End ASR Model},
+ author={Xian Shi and Yanni Chen and Shiliang Zhang and Zhijie Yan},
+ booktitle={arXiv preprint arXiv:2301.12343},
+ year={2023}
+}
```
\ No newline at end of file
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
index 8f58259..716d44e 100644
--- a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
@@ -1,5 +1,5 @@
# Paraformer-Large
-- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
+- Model link: <https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
- Model size: 45M
# Environments
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
index 281292f..bf8176e 100755
--- a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
@@ -24,12 +24,12 @@
if __name__ == '__main__':
- params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
+ params = modelscope_args(model="NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
params.output_dir = "./checkpoint" # m模型保存路径
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small，若数据量大于1000小时，请使用large
params.batch_bins = 1000 # batch size，如果dataset_type="small"，batch_bins单位为fbank特征帧数，如果dataset_type="large"，batch_bins单位为毫秒，
params.max_epoch = 10 # 最大训练轮数
params.lr = 0.0001 # 设置学习率
- params.model_revision = 'v2.0.0'
+ params.model_revision = 'v1.0.0'
modelscope_finetune(params)
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
index 3054394..fa22aad 100755
--- a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
@@ -18,8 +18,8 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipline = pipeline(
task=Tasks.auto_speech_recognition,
- model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
- model_revision='v2.0.0',
+ model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
+ model_revision='v1.0.0',
output_dir=output_dir_job,
batch_size=1,
)
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
index 00faad0..e714a3d 100755
--- a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
@@ -59,7 +59,7 @@
if __name__ == '__main__':
params = {}
- params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
+ params["modelscope_model_name"] = "NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
params["output_dir"] = "./checkpoint"
params["data_dir"] = "./example_data/validation"
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index be35e78..18d788e 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -41,16 +41,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
+from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
class Speech2Text:
@@ -346,6 +337,160 @@
# assert check_return_type(results)
return results
+class Speech2TextExport:
+ """Speech-to-text wrapper that runs inference through the export form of the Paraformer model.
+ The loaded ASR model is wrapped in Paraformer_export(onnx=False); decoding is a per-position argmax.
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
+ **kwargs,
+ ):
+ # NOTE: lm_*, maxlen/minlen ratios, beam/weight/penalty args, frontend_conf, hotword_list_or_file and kwargs are accepted but not used in this constructor.
+ # 1. Build ASR model and its training args from the config / checkpoint files
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ token_list = asr_model.token_list
+
+
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 2. [Optional] Build text tokenizer (e.g. bpe-sym -> text); token_type/bpemodel fall back to the training args
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ # The raw asr_model is not stored; self.asr_model is set below to the export-wrapped model.
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ model = Paraformer_export(asr_model, onnx=False)
+ self.asr_model = model
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ):
+ """Run the frontend (when configured) and the export model, then decode each batch item by argmax.
+
+ Args:
+ speech: Input speech data (np.ndarray is converted to torch.Tensor); speech_lengths: lengths passed to the frontend/model
+ Returns:
+ list with one (text, token, token_int, hyp, enc_len_batch_total, lfr_factor) tuple per batch item; text is None without a tokenizer
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ # NOTE(review): the constant 80 below is presumably the per-frame feature dimension used to derive lfr_factor — confirm
+ enc_len_batch_total = feats_len.sum()
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+ # Only the first two outputs of the export model are used: decoder output and per-item output lengths
+ decoder_outs = self.asr_model(**batch)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ am_scores = decoder_out[i, :ys_pad_lens[i], :]
+ # Greedy decode: best id per position; the hypothesis score is the sum of the per-position maxima
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # NOTE(review): no sos/eos ids are added here, yet yseq[1:last_pos] below drops the first and last ids — confirm this is intended
+ yseq = torch.tensor(
+ yseq.tolist(), device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # drop the first and last ids (treated as sos/eos) and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove id 0 (assumed to be the blank symbol) and id 2 (NOTE(review): presumably sos/eos — confirm)
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+
+ return results
+
def inference(
maxlenratio: float,
@@ -454,9 +599,11 @@
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
+ export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
+ export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
@@ -490,7 +637,10 @@
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
- speech2text = Speech2Text(**speech2text_kwargs)
+ if export_mode:
+ speech2text = Speech2TextExport(**speech2text_kwargs)
+ else:
+ speech2text = Speech2Text(**speech2text_kwargs)
def _forward(
data_path_and_name_and_type,
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index dbb2719..c01c6ba 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -38,7 +38,6 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.bin.punctuation_infer import Text2Punc
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index ee36135..f194830 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -39,7 +39,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
+from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer
@@ -282,13 +282,10 @@
else:
text = None
- if isinstance(self.asr_model, BiCifParaformer):
- timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
- results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
- else:
- time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time,
- end_time)
- results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+
+ timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+
# assert check_return_type(results)
return results
@@ -636,7 +633,8 @@
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
- if use_timestamp and time_stamp is not None:
+
+ if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index b5c6fa8..972f92f 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -44,6 +44,7 @@
model,
self.export_config,
)
+ model.eval()
# self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)
diff --git a/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
index d51c6bf..64dbaf8 100644
--- a/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
@@ -41,8 +41,8 @@
self.ort_infer = OrtInferSession(model_file, device_id)
self.batch_size = batch_size
- def __call__(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
- waveform_list = self.load_data(wav_content, fs)
+ def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
+ waveform_list = self.load_data(wav_content, self.frontend.opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 33d1255..f966aee 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -4,88 +4,6 @@
import numpy as np
from typing import Any, List, Tuple, Union
-def cut_interval(alphas: torch.Tensor, start: int, end: int, tail: bool):
- if not tail:
- if end == start + 1:
- cut = (end + start) / 2.0
- else:
- alpha = alphas[start+1: end].tolist()
- reverse_steps = 1
- for reverse_alpha in alpha[::-1]:
- if reverse_alpha > 0.35:
- reverse_steps += 1
- else:
- break
- cut = end - reverse_steps
- else:
- if end != len(alphas) - 1:
- cut = end + 1
- else:
- cut = start + 1
- return float(cut)
-
-def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text: List[str], begin: int = 0, end: int = None):
- time_stamp_list = []
- alphas = alphas[0]
- text = copy.deepcopy(raw_text)
- if end is None:
- time = speech_lengths * 60 / 1000
- sacle_rate = (time / speech_lengths[0]).tolist()
- else:
- time = (end - begin) / 1000
- sacle_rate = (time / speech_lengths[0]).tolist()
-
- predictor = (alphas > 0.5).int()
- fire_places = torch.nonzero(predictor == 1).squeeze(1).tolist()
-
- cuts = []
- npeak = int(predictor.sum())
- nchar = len(raw_text)
- if npeak - 1 == nchar:
- fire_places = torch.where((alphas > 0.5) == 1)[0].tolist()
- for i in range(len(fire_places)):
- if fire_places[i] < len(alphas) - 1:
- if 0.05 < alphas[fire_places[i]+1] < 0.5:
- fire_places[i] += 1
- elif npeak < nchar:
- lost_num = nchar - npeak
- lost_fire = speech_lengths[0].tolist() - fire_places[-1]
- interval_distance = lost_fire // (lost_num + 1)
- for i in range(1, lost_num + 1):
- fire_places.append(fire_places[-1] + interval_distance)
- elif npeak - 1 > nchar:
- redundance_num = npeak - 1 - nchar
- for i in range(redundance_num):
- fire_places.pop()
-
- cuts.append(0)
- start_sil = True
- if start_sil:
- text.insert(0, '<sil>')
-
- for i in range(len(fire_places)-1):
- cuts.append(cut_interval(alphas, fire_places[i], fire_places[i+1], tail=(i==len(fire_places)-2)))
-
- for i in range(2, len(fire_places)-2):
- if fire_places[i-2] == fire_places[i-1] - 1 and fire_places[i-1] != fire_places[i] - 1:
- cuts[i-1] += 1
-
- if cuts[-1] != len(alphas) - 1:
- text.append('<sil>')
- cuts.append(speech_lengths[0].tolist())
- cuts.insert(-1, (cuts[-1] + cuts[-2]) * 0.5)
- sec_fire_places = np.array(cuts) * sacle_rate
- for i in range(1, len(sec_fire_places) - 1):
- start, end = sec_fire_places[i], sec_fire_places[i+1]
- if i == len(sec_fire_places) - 2:
- end = time
- time_stamp_list.append([int(round(start, 2) * 1000) + begin, int(round(end, 2) * 1000) + begin])
- text = text[1:]
- if npeak - 1 == nchar or npeak > nchar:
- return time_stamp_list[:-1]
- else:
- return time_stamp_list
-
def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
START_END_THRESHOLD = 5
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
--
Gitblit v1.9.1