From c7fc6149b3c5c2de3107c4f1d4983309882d1a1a Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Wed, 07 Jun 2023 14:57:49 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/runtime/onnxruntime/src/ct-transformer.cpp | 7
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py | 1
funasr/runtime/onnxruntime/src/CMakeLists.txt | 17
funasr/runtime/html5/static/wsconnecter.js | 7
funasr/runtime/python/websocket/wss_srv_asr.py | 41 +
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py | 1
docs/model_zoo/modelscope_models.md | 4
MODEL_LICENSE | 6
funasr/datasets/large_datasets/build_dataloader.py | 3
funasr/runtime/onnxruntime/src/fsmn-vad-online.h | 88 +++
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py | 1
egs/librispeech/rnnt/README.md | 18
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 51 +
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py | 1
egs/librispeech_100h/rnnt/README.md | 4
egs/librispeech/rnnt/run.sh | 222 ++++++++
funasr/runtime/onnxruntime/src/precomp.h | 3
funasr/runtime/onnxruntime/include/vad-model.h | 9
funasr/runtime/onnxruntime/src/audio.cpp | 78 ++
funasr/runtime/onnxruntime/bin/CMakeLists.txt | 16
egs/librispeech/rnnt/local/download_and_untar.sh | 97 +++
egs/librispeech/rnnt/utils | 1
README.md | 1
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py | 2
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp | 198 +++++++
funasr/runtime/onnxruntime/src/fsmn-vad.h | 45
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 14
funasr/version.txt | 2
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh | 2
egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml | 8
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp | 20
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 18
funasr/runtime/onnxruntime/src/vad-model.cpp | 15
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh | 2
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py | 1
egs/aishell/rnnt/run.sh | 2
egs/librispeech/rnnt/path.sh | 5
egs_modelscope/tp/TEMPLATE/README.md | 2
funasr/runtime/onnxruntime/include/audio.h | 13
funasr/runtime/onnxruntime/include/funasrruntime.h | 13
fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv | 16
funasr/runtime/onnxruntime/CMakeLists.txt | 13
funasr/runtime/html5/static/main.js | 4
egs/librispeech/rnnt/local/data_prep.sh | 58 ++
egs/librispeech/rnnt/local/spm_train.py | 12
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 12
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py | 2
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp | 0
funasr/runtime/onnxruntime/src/paraformer.h | 4
fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py | 15
egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml | 98 +++
/dev/null | 58 --
funasr/runtime/websocket/CMakeLists.txt | 2
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp | 89 ++
egs_modelscope/asr/TEMPLATE/README.md | 4
egs/librispeech/rnnt/local/spm_encode.py | 98 +++
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py | 1
egs/aishell/rnnt/README.md | 10
58 files changed, 1281 insertions(+), 254 deletions(-)
diff --git a/MODEL_LICENSE b/MODEL_LICENSE
index e505980..3d9e410 100644
--- a/MODEL_LICENSE
+++ b/MODEL_LICENSE
@@ -9,7 +9,7 @@
涓轰簡淇濊瘉鏇村ソ鐨勭ぞ鍖哄悎浣滐紝鎴戜滑鍒跺畾浜嗕互涓嬪崗璁紝甯屾湜鎮ㄤ粩缁嗛槄璇诲苟閬靛畧鏈崗璁��
1 瀹氫箟
-鏈崗璁腑锛孾FunASR 杞欢]鎸� FunASR 寮�婧愭ā鍨嬪強鍏惰鐢熷搧锛屽寘鎷� Finetune 鍚庣殑妯″瀷锛沎鎮╙鎸囦娇鐢ㄣ�佷慨鏀广�佸垎浜拰瀛︿範[FunASR 杞欢]鐨勪釜浜烘垨缁勭粐銆�
+鏈崗璁腑锛孾FunASR 杞欢]鎸� FunASR 寮�婧愭ā鍨嬫潈閲嶅強鍏惰鐢熷搧锛屽寘鎷� Finetune 鍚庣殑妯″瀷锛沎鎮╙鎸囦娇鐢ㄣ�佷慨鏀广�佸垎浜拰瀛︿範[FunASR 杞欢]鐨勪釜浜烘垨缁勭粐銆�
2 璁稿彲鍜岄檺鍒�
2.1 璁稿彲
@@ -18,7 +18,7 @@
2.2 闄愬埗
-鎮ㄥ湪浣跨敤銆佸鍒躲�佷慨鏀瑰拰鍒嗕韩[FunASR 杞欢]鏃讹紝蹇呴』娉ㄦ槑鍑哄浠ュ強浣滆�呬俊鎭�傚苟涓旓紝灏哰FunASR 杞欢]涓婁紶鑷冲叾浠栫涓夋柟骞冲彴浠ヤ緵涓嬭浇锛岄渶瑕佽幏寰楅澶栬鍙紝鍙�氳繃瀹樻柟閭欢锛坒unasr@list.alibaba-inc.com锛夎繘琛岀敵璇凤紙鍏嶈垂锛夈��
+鎮ㄥ湪浣跨敤銆佸鍒躲�佷慨鏀瑰拰鍒嗕韩[FunASR 杞欢]鏃讹紝蹇呴』娉ㄦ槑鍑哄浠ュ強浣滆�呬俊鎭紝骞朵繚鐣橻FunASR 杞欢]涓浉鍏虫ā鍨嬪悕绉般��
3 璐d换鍜岄闄╂壙鎷�
[FunASR 杞欢]浠呬綔涓哄弬鑰冨拰瀛︿範浣跨敤锛屼笉瀵规偍浣跨敤鎴栦慨鏀筟FunASR 杞欢]閫犳垚鐨勪换浣曠洿鎺ユ垨闂存帴鎹熷け鎵挎媴浠讳綍璐d换銆傛偍瀵筟FunASR 杞欢]鐨勪娇鐢ㄥ拰淇敼搴旇鑷鎵挎媴椋庨櫓銆�
@@ -54,7 +54,7 @@
You are free to use, copy, modify, and share [FunASR software] under the conditions of this agreement.
2.2 Restrictions
-You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. To upload the [FunASR software] to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com).
+You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. You should keep the relevant names of models in [FunASR software].
3 Responsibility and Risk
[FunASR software] is for reference and learning purposes only and is not responsible for any direct or indirect losses caused by your use or modification of [FunASR software]. You should take responsibility and risks for your use and modification of [FunASR software].
diff --git a/README.md b/README.md
index c31d616..7c289e0 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,6 @@
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
-| [**Tutorial_CN**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
diff --git a/docs/model_zoo/modelscope_models.md b/docs/model_zoo/modelscope_models.md
index cdbe3ff..2e4e51c 100644
--- a/docs/model_zoo/modelscope_models.md
+++ b/docs/model_zoo/modelscope_models.md
@@ -1,8 +1,10 @@
# Pretrained Models on ModelScope
## Model License
-You are free to use, copy, modify, and share FunASR models under the conditions of this agreement. You should indicate the model source and author information when using, copying, modifying and sharing FunASR models. To upload the FunASR models to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com). Full model license could see [license](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE)
+You are free to use, copy, modify, and share FunASR models under the conditions of this agreement. You should indicate the model source and author information when using, copying, modifying and sharing FunASR models. You should keep the relevant names of models in [FunASR software]. Full model license could see [license](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE)
+## Model Usage
+Ref to [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
## Model Zoo
Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md
index 45f1f3f..817c9b2 100644
--- a/egs/aishell/rnnt/README.md
+++ b/egs/aishell/rnnt/README.md
@@ -5,14 +5,14 @@
- 8 gpu(Tesla V100)
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train config: conf/train_conformer_rnnt_unified.yaml
-- chunk config: chunk size 16, full left chunk
+- chunk config: chunk size 16, 1 left chunk
- LM config: LM was not used
- Model size: 90M
## Results (CER)
-- Decode config: conf/train_conformer_rnnt_unified.yaml
+- Decode config: conf/decode_rnnt_conformer_streaming.yaml
-| testset | CER(%) |
+| testset | CER(%) |
|:-----------:|:-------:|
-| dev | 5.53 |
-| test | 6.24 |
+| dev | 5.43 |
+| test | 6.04 |
diff --git a/egs/aishell/rnnt/run.sh b/egs/aishell/rnnt/run.sh
index 893a197..9facc8b 100755
--- a/egs/aishell/rnnt/run.sh
+++ b/egs/aishell/rnnt/run.sh
@@ -4,7 +4,7 @@
# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-gpu_num=2
+gpu_num=8
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
diff --git a/egs/librispeech/rnnt/README.md b/egs/librispeech/rnnt/README.md
new file mode 100644
index 0000000..c64e2e8
--- /dev/null
+++ b/egs/librispeech/rnnt/README.md
@@ -0,0 +1,18 @@
+
+# Streaming RNN-T Result
+
+## Training Config
+- 8 gpu(Tesla V100)
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train config: conf/train_conformer_rnnt_unified.yaml
+- chunk config: chunk size 16, 1 left chunk
+- LM config: LM was not used
+- Model size: 90M
+
+## Results (CER)
+- Decode config: conf/decode_rnnt_conformer_streaming.yaml
+
+| testset | WER(%) |
+|:--------------:|:-------:|
+| test_clean | 3.58 |
+| test_other | 9.27 |
diff --git a/egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml b/egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml
new file mode 100644
index 0000000..26e43c6
--- /dev/null
+++ b/egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml
@@ -0,0 +1,8 @@
+# The conformer transducer decoding configuration from @jeon30c
+beam_size: 10
+simu_streaming: false
+streaming: true
+chunk_size: 16
+left_context: 16
+right_context: 0
+
diff --git a/egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml
new file mode 100644
index 0000000..39c6a6a
--- /dev/null
+++ b/egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -0,0 +1,98 @@
+encoder: chunk_conformer
+encoder_conf:
+ activation_type: swish
+ time_reduction_factor: 2
+ unified_model_training: true
+ default_chunk_size: 16
+ jitter_range: 4
+ left_chunk_size: 1
+ embed_vgg_like: false
+ subsampling_factor: 4
+ linear_units: 2048
+ output_size: 512
+ attention_heads: 8
+ dropout_rate: 0.5
+ positional_dropout_rate: 0.5
+ attention_dropout_rate: 0.5
+ cnn_module_kernel: 15
+ num_blocks: 12
+
+# decoder related
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
+ embed_size: 512
+ hidden_size: 512
+ embed_dropout_rate: 0.5
+ dropout_rate: 0.5
+ use_embed_mask: true
+
+joint_network_conf:
+ joint_space_size: 512
+
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# Auxiliary CTC
+model: rnnt_unified
+model_conf:
+ auxiliary_ctc_weight: 0.0
+
+# minibatch related
+use_amp: true
+
+# optimization related
+accum_grad: 4
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - cer_transducer_chunk
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 40
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 50
+ num_time_mask: 5
+
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 1024
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 10000
+ num_workers: 8
+
+log_interval: 50
+normalize: None
diff --git a/egs/librispeech/rnnt/local/data_prep.sh b/egs/librispeech/rnnt/local/data_prep.sh
new file mode 100755
index 0000000..c939b5f
--- /dev/null
+++ b/egs/librispeech/rnnt/local/data_prep.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Vassil Panayotov
+# 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+if [ "$#" -ne 2 ]; then
+ echo "Usage: $0 <src-dir> <dst-dir>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
+ exit 1
+fi
+
+src=$1
+dst=$2
+
+# all utterances are FLAC compressed
+if ! which flac >&/dev/null; then
+ echo "Please install 'flac' on ALL worker nodes!"
+ exit 1
+fi
+
+spk_file=$src/../SPEAKERS.TXT
+
+mkdir -p $dst || exit 1
+
+[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
+[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
+
+
+wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
+trans=$dst/text; [[ -f "$trans" ]] && rm $trans
+
+for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
+ reader=$(basename $reader_dir)
+ if ! [ $reader -eq $reader ]; then # not integer.
+ echo "$0: unexpected subdirectory name $reader"
+ exit 1
+ fi
+
+ for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
+ chapter=$(basename $chapter_dir)
+ if ! [ "$chapter" -eq "$chapter" ]; then
+ echo "$0: unexpected chapter-subdirectory name $chapter"
+ exit 1
+ fi
+
+ find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
+ awk -v "dir=$chapter_dir" '{printf "%s %s/%s.flac \n", $0, dir, $0}' >>$wav_scp|| exit 1
+
+ chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
+ [ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
+ cat $chapter_trans >>$trans
+ done
+done
+
+echo "$0: successfully prepared data in $dst"
+
+exit 0
diff --git a/egs/librispeech/rnnt/local/download_and_untar.sh b/egs/librispeech/rnnt/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech/rnnt/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# sizes of the archive files in bytes. This is some older versions.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new is the archive file sizes of the final release. Some of these sizes are of
+# things we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
diff --git a/egs/librispeech/rnnt/local/spm_encode.py b/egs/librispeech/rnnt/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech/rnnt/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True,
+ help="sentencepiece model to use for encoding")
+ parser.add_argument("--inputs", nargs="+", default=['-'],
+ help="input files to filter/encode")
+ parser.add_argument("--outputs", nargs="+", default=['-'],
+ help="path to save encoded outputs")
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument("--min-len", type=int, metavar="N",
+ help="filter sentence pairs with fewer than N tokens")
+ parser.add_argument("--max-len", type=int, metavar="N",
+ help="filter sentence pairs with more than N tokens")
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(args.outputs), \
+ "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+ def encode(l):
+ return sp.EncodeAsPieces(l)
+ elif args.output_format == "id":
+ def encode(l):
+ return list(map(str, sp.EncodeAsIds(l)))
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+ def valid(line):
+ return (
+ (args.min_len is None or len(line) >= args.min_len) and
+ (args.max_len is None or len(line) <= args.max_len)
+ )
+ else:
+ def valid(lines):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-" else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-" else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/rnnt/local/spm_train.py b/egs/librispeech/rnnt/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech/rnnt/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/egs/librispeech/rnnt/path.sh b/egs/librispeech/rnnt/path.sh
new file mode 100644
index 0000000..7972642
--- /dev/null
+++ b/egs/librispeech/rnnt/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/librispeech/rnnt/run.sh b/egs/librispeech/rnnt/run.sh
new file mode 100755
index 0000000..efccd85
--- /dev/null
+++ b/egs/librispeech/rnnt/run.sh
@@ -0,0 +1,222 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+gpu_num=8
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="" #feature output dictionary
+exp_dir=""
+lang=en
+token_type=bpe
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=
+data_url=www.openslr.org/resources/12
+
+# bpe model
+nbpe=5000
+bpemode=unigram
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train_960
+valid_set=dev
+test_sets="test_clean test_other dev_clean dev_other"
+
+asr_config=conf/train_conformer_rnnt_unified.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_rnnt_conformer_streaming.yaml
+inference_asr_model=valid.cer_transducer_chunk.ave_10best.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+ local/download_and_untar.sh ${raw_data} ${data_url} ${part}
+ done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ for x in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do
+ local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+ done
+ mkdir $feats_dir/data/$valid_set
+ dev_sets="dev_clean dev_other"
+ for file in wav.scp text; do
+ ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+ done
+ mkdir $feats_dir/data/$train_set
+ train_sets="train_clean_100 train_clean_360 train_other_500"
+ for file in wav.scp text; do
+ ( for f in $train_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$train_set/$file || exit 1;
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ ### Task dependent. You have to check non-linguistic symbols used in the corpus.
+ echo "stage 2: Dictionary and Json Data Preparation"
+ mkdir -p ${feats_dir}/data/lang_char/
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
+ local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+ local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=./ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --bpemodel ${bpemodel}.model \
+ --token_type $token_type \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/data/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/compute_wer.py ${_data}/text ${_dir}/text ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
diff --git a/egs/librispeech/rnnt/utils b/egs/librispeech/rnnt/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/librispeech/rnnt/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/librispeech_100h/rnnt/README.md b/egs/librispeech_100h/rnnt/README.md
index 7fd1af3..abadb2d 100644
--- a/egs/librispeech_100h/rnnt/README.md
+++ b/egs/librispeech_100h/rnnt/README.md
@@ -8,9 +8,9 @@
- Model size: 30.54M
## Results (CER)
-- Decode config: conf/decode_rnnt_transformer.yaml (ctc weight:0.5)
+- Decode config: conf/decode_rnnt_conformer.yaml
-| testset | WER(%) |
+| testset | WER(%) |
|:--------------:|:-------:|
| test_clean | 6.64 |
| test_other | 17.12 |
diff --git a/egs_modelscope/asr/TEMPLATE/README.md b/egs_modelscope/asr/TEMPLATE/README.md
index 2dfaefe..0219c5b 100644
--- a/egs_modelscope/asr/TEMPLATE/README.md
+++ b/egs_modelscope/asr/TEMPLATE/README.md
@@ -26,7 +26,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
- update_model='v1.0.6',
+ update_model=False,
mode='paraformer_streaming'
)
import soundfile
@@ -54,7 +54,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
- update_model='v1.0.6',
+ update_model=False,
mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
index 283f529..bef3849 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -5,6 +5,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
+ update_model=False,
mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
index 668db95..c1f4afe 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
@@ -15,7 +15,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
- update_model='v1.0.6',
+ update_model=False,
mode="paraformer_streaming"
)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
index a72321a..30034aa 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
@@ -15,6 +15,7 @@
kwargs = dict(
model=params.model,
model_revision='v1.0.6',
+ update_model=False,
data_dir=ds_dict,
dataset_type=params.dataset_type,
work_dir=params.output_dir,
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index 4a823aa..241ebef 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -12,6 +12,7 @@
output_dir=args.output_dir,
batch_size=args.batch_size,
model_revision='v1.0.6',
+ update_model=False,
mode="paraformer_fake_streaming",
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
index 6d4d625..0cab31e 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
@@ -54,7 +54,7 @@
--audio_in ${output_dir}/split/wav.$JOB.scp \
--output_dir ${output_dir}/output.$JOB \
--batch_size ${batch_size} \
- --gpuid ${gpuid}
+ --gpuid ${gpuid} \
--mode "paraformer_fake_streaming"
}&
done
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
index 5fa417b..dae7766 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -5,6 +5,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
+ update_model=False,
mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
index af5b22a..199fcd8 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
@@ -15,7 +15,7 @@
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
model_revision='v1.0.6',
- update_model='v1.0.6',
+ update_model=False,
mode="paraformer_streaming"
)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
index a9251ef..bd58c87 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
@@ -15,6 +15,7 @@
kwargs = dict(
model=params.model,
model_revision='v1.0.6',
+ update_model=False,
data_dir=ds_dict,
dataset_type=params.dataset_type,
work_dir=params.output_dir,
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index 4a823aa..241ebef 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -12,6 +12,7 @@
output_dir=args.output_dir,
batch_size=args.batch_size,
model_revision='v1.0.6',
+ update_model=False,
mode="paraformer_fake_streaming",
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
index 46c2bb3..f565825 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
@@ -54,7 +54,7 @@
--audio_in ${output_dir}/split/wav.$JOB.scp \
--output_dir ${output_dir}/output.$JOB \
--batch_size ${batch_size} \
- --gpuid ${gpuid}
+ --gpuid ${gpuid} \
--mode "paraformer_fake_streaming"
}&
done
diff --git a/egs_modelscope/tp/TEMPLATE/README.md b/egs_modelscope/tp/TEMPLATE/README.md
index 7cc8508..3c7129f 100644
--- a/egs_modelscope/tp/TEMPLATE/README.md
+++ b/egs_modelscope/tp/TEMPLATE/README.md
@@ -11,7 +11,7 @@
inference_pipeline = pipeline(
task=Tasks.speech_timestamp,
model='damo/speech_timestamp_prediction-v1-16k-offline',
- output_dir=None)
+ model_revision='v1.1.0')
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
diff --git a/fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv b/fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv
index 662bf04..49eda5f 100644
--- a/fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv
+++ b/fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv
@@ -1,10 +1,10 @@
dua ribu dua puluh dua 2022
-tiga ribu 300
+tiga ribu 3000
sembilan ribu sembilan ratus sembilan puluh sembilan 9999
-seribu satu 100001
-ribu 100
+seribu satu 1001
+ribu 1000
seribu 1000
-seribu dua ratus delapan puluh sembilan 10289
+seribu dua ratus delapan puluh sembilan 1289
ribu dua ratus delapan puluh sembilan 1289
nol satu dua tiga empat lima enam tujuh delapan sembilan 01 2345-6789
empat belas 14
@@ -22,8 +22,8 @@
seratus dua puluh tiga 123
ratus dua puluh tiga 123
dua puluh empat maret 24 maret
-ribu tujuh puluh enam 10076
-seribu tujuh puluh enam 100076
-ribu tujuh puluh enam rupiah 10076 rupiah
+ribu tujuh puluh enam 1076
+seribu tujuh puluh enam 1076
+ribu tujuh puluh enam rupiah 1076 rupiah
tujuh puluh enam 76
-ditambah enam dua dua satu enam lima tiga sembilan nol enam nol lima +62 21 6539-0605
\ No newline at end of file
+ditambah enam dua dua satu enam lima tiga sembilan nol enam nol lima +62 21 6539-0605
diff --git a/fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py b/fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py
index 539acbc..d2f1a77 100644
--- a/fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py
+++ b/fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py
@@ -26,11 +26,10 @@
graph_teen = pynini.string_file(get_abs_path("data/numbers/teen.tsv"))
graph_hundreds = pynini.string_file(get_abs_path("data/numbers/hundreds.tsv"))
graph_thousand = pynini.string_file(get_abs_path("data/numbers/thousand.tsv"))
-
- graph_cents = pynini.cross("seratus", "100") | pynini.cross("ratus", "100") | pynini.union(graph_hundreds, pynutil.insert("0"))
+
graph_hundred = pynini.cross("ratus", "") | pynini.cross("seratus", "")
- graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("00"))
+ graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("0"))
graph_hundred_component += delete_space
graph_hundred_component += pynini.union(
graph_teen | pynutil.insert("00"),
@@ -44,8 +43,8 @@
(graph_ties | pynutil.insert("0")) + delete_space + (
graph_digit | pynutil.insert("0")),
)
- graph_hundred_component = graph_hundred_component | graph_cents | graph_one_hundred_component
-
+ graph_hundred_component = graph_hundred_component | graph_one_hundred_component
+
graph_hundred_component_at_least_one_none_zero_digit = graph_hundred_component @ (
pynini.closure(DAMO_DIGIT) + (DAMO_DIGIT - "0") + pynini.closure(DAMO_DIGIT)
)
@@ -54,14 +53,12 @@
)
graph_thousand = pynini.cross("ribu", "") | pynini.cross("seribu", "")
graph_one_thousand_component = pynini.union(pynini.cross("ribu", "1") | pynini.cross("seribu", "1"))
- graph_thousand_cents = pynini.cross("seribu", "10") | pynini.cross("ribu","10") | pynini.union(graph_thousand, pynutil.insert(""))
+
graph_thousands = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + (pynutil.delete("ribu") | pynutil.delete("seribu")),
pynutil.insert("000", weight=0.1),
)
- graph_thousand_component = pynini.union(graph_digit + delete_space + graph_thousand, pynutil.insert("000"))
- graph_thousand_component += delete_space
- graph_thousands = graph_thousands | graph_thousand_cents | graph_thousand_component | graph_one_thousand_component
+ graph_thousands = graph_thousands | (pynutil.insert("00") + graph_one_thousand_component)
graph_million = pynini.union(
graph_hundred_component_at_least_one_none_zero_digit + delete_space + (pynutil.delete("juta") | pynutil.delete("sejuta")),
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 339292f..aa5d9be 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -77,12 +77,13 @@
bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
self.dataset_conf = args.dataset_conf
self.frontend_conf = args.frontend_conf
+ self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
data_list = args.train_data_file if mode == "train" else args.valid_data_file
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
self.dataset_conf, self.frontend_conf,
- speed_perturb=args.speed_perturb if mode == "train" else None,
+ speed_perturb=self.speed_perturb if mode == "train" else None,
mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
diff --git a/funasr/runtime/html5/static/main.js b/funasr/runtime/html5/static/main.js
index be57df1..9317778 100644
--- a/funasr/runtime/html5/static/main.js
+++ b/funasr/runtime/html5/static/main.js
@@ -145,7 +145,9 @@
isRec = false;
info_div.innerHTML="璇风瓑鍊�...";
btnStop.disabled = true;
- setTimeout(function(){btnStart.disabled = false;info_div.innerHTML="璇风偣鍑诲紑濮�";}, 3000 );
+ setTimeout(function(){
+ console.log("call stop ws!");
+ wsconnecter.wsStop();btnStart.disabled = false;info_div.innerHTML="璇风偣鍑诲紑濮�";}, 3000 );
rec.stop(function(blob,duration){
console.log(blob);
diff --git a/funasr/runtime/html5/static/wsconnecter.js b/funasr/runtime/html5/static/wsconnecter.js
index dfa8235..676a94a 100644
--- a/funasr/runtime/html5/static/wsconnecter.js
+++ b/funasr/runtime/html5/static/wsconnecter.js
@@ -28,7 +28,11 @@
if ( 'WebSocket' in window ) {
speechSokt = new WebSocket( Uri ); // 瀹氫箟socket杩炴帴瀵硅薄
speechSokt.onopen = function(e){onOpen(e);}; // 瀹氫箟鍝嶅簲鍑芥暟
- speechSokt.onclose = function(e){onClose(e);};
+ speechSokt.onclose = function(e){
+ console.log("onclose ws!");
+ speechSokt.close();
+ onClose(e);
+ };
speechSokt.onmessage = function(e){onMessage(e);};
speechSokt.onerror = function(e){onError(e);};
return 1;
@@ -42,6 +46,7 @@
// 瀹氫箟鍋滄涓庡彂閫佸嚱鏁�
this.wsStop = function () {
if(speechSokt != undefined) {
+ console.log("stop ws!");
speechSokt.close();
}
};
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 9f6013f..0847d1f 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -7,6 +7,8 @@
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
@@ -30,12 +32,13 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
-add_subdirectory(third_party/yaml-cpp)
-add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
-add_subdirectory(src)
-
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(third_party/glog)
-endif()
\ No newline at end of file
+endif()
+
+add_subdirectory(third_party/yaml-cpp)
+add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
+add_subdirectory(src)
+add_subdirectory(bin)
diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
new file mode 100644
index 0000000..962da0b
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -0,0 +1,16 @@
+include_directories(${CMAKE_SOURCE_DIR}/include)
+
+add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
+target_link_libraries(funasr-onnx-offline PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
+target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
+target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
+target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
+target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
similarity index 100%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
similarity index 93%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 2d182e0..d2692ce 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -28,7 +28,7 @@
std::atomic<int> wav_index(0);
std::mutex mtx;
-void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
+void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
float* total_length, long* total_time, int core_id) {
struct timeval start, end;
@@ -59,7 +59,7 @@
if(result){
string msg = FunASRGetResult(result, 0);
- LOG(INFO) << "Thread: " << this_thread::get_id() <<" Result: " << msg.c_str();
+ LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
@@ -146,9 +146,12 @@
// read wav_path
vector<string> wav_list;
+ vector<string> wav_ids;
+ string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
}
else if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
@@ -162,7 +165,8 @@
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
- wav_list.emplace_back(column2);
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
}
in.close();
}else{
@@ -178,7 +182,7 @@
int rtf_threds = thread_num.getValue();
for (int i = 0; i < rtf_threds; i++)
{
- threads.emplace_back(thread(runReg, asr_handle, wav_list, &total_length, &total_time, i));
+ threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
}
for (auto& thread : threads)
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
similarity index 90%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
index 0f606c6..3aaffbd 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
@@ -38,8 +38,8 @@
}
}
-void print_segs(vector<vector<int>>* vec) {
- string seg_out="[";
+void print_segs(vector<vector<int>>* vec, string &wav_id) {
+ string seg_out=wav_id + ": [";
for (int i = 0; i < vec->size(); i++) {
vector<int> inner_vec = (*vec)[i];
seg_out += "[";
@@ -97,9 +97,12 @@
// read wav_path
vector<string> wav_list;
+ vector<string> wav_ids;
+ string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
}
else if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
@@ -113,7 +116,8 @@
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
- wav_list.emplace_back(column2);
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
}
in.close();
}else{
@@ -123,9 +127,11 @@
float snippet_time = 0.0f;
long taking_micros = 0;
- for(auto& wav_file : wav_list){
+ for (int i = 0; i < wav_list.size(); i++) {
+ auto& wav_file = wav_list[i];
+ auto& wav_id = wav_ids[i];
gettimeofday(&start, NULL);
- FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
+ FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -133,7 +139,7 @@
if (result)
{
vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
- print_segs(vad_segments);
+ print_segs(vad_segments, wav_id);
snippet_time += FsmnVadGetRetSnippetTime(result);
FsmnVadFreeResult(result);
}
@@ -142,7 +148,7 @@
LOG(ERROR) << ("No return data!\n");
}
}
-
+
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
similarity index 92%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index 3472925..82668f8 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -88,9 +88,12 @@
// read wav_path
vector<string> wav_list;
- string wav_path_ = model_path.at(WAV_PATH);
+ vector<string> wav_ids;
+ string default_id = "wav_default_id";
+ string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
}
else if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
@@ -104,7 +107,8 @@
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
- wav_list.emplace_back(column2);
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
}
in.close();
}else{
@@ -114,7 +118,9 @@
float snippet_time = 0.0f;
long taking_micros = 0;
- for(auto& wav_file : wav_list){
+ for (int i = 0; i < wav_list.size(); i++) {
+ auto& wav_file = wav_list[i];
+ auto& wav_id = wav_ids[i];
gettimeofday(&start, NULL);
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, 16000);
gettimeofday(&end, NULL);
@@ -124,7 +130,7 @@
if (result)
{
string msg = FunASRGetResult(result, 0);
- LOG(INFO)<<"Result: "<<msg;
+ LOG(INFO)<< wav_id <<" : "<<msg;
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
similarity index 62%
copy from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
copy to funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
index 0f606c6..68e32e5 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
@@ -18,6 +18,7 @@
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
+#include "audio.h"
using namespace std;
@@ -38,10 +39,16 @@
}
}
-void print_segs(vector<vector<int>>* vec) {
- string seg_out="[";
+void print_segs(vector<vector<int>>* vec, string &wav_id) {
+ if((*vec).size() == 0){
+ return;
+ }
+ string seg_out=wav_id + ": [";
for (int i = 0; i < vec->size(); i++) {
vector<int> inner_vec = (*vec)[i];
+ if(inner_vec.size() == 0){
+ continue;
+ }
seg_out += "[";
for (int j = 0; j < inner_vec.size(); j++) {
seg_out += to_string(inner_vec[j]);
@@ -97,9 +104,12 @@
// read wav_path
vector<string> wav_list;
+ vector<string> wav_ids;
+ string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
}
else if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
@@ -113,39 +123,76 @@
istringstream iss(line);
string column1, column2;
iss >> column1 >> column2;
- wav_list.emplace_back(column2);
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
}
in.close();
}else{
LOG(ERROR)<<"Please check the wav extension!";
exit(-1);
}
-
+ // init online features
+ FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
float snippet_time = 0.0f;
long taking_micros = 0;
- for(auto& wav_file : wav_list){
- gettimeofday(&start, NULL);
- FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
- gettimeofday(&end, NULL);
- seconds = (end.tv_sec - start.tv_sec);
- taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ for (int i = 0; i < wav_list.size(); i++) {
+ auto& wav_file = wav_list[i];
+ auto& wav_id = wav_ids[i];
- if (result)
- {
- vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
- print_segs(vad_segments);
- snippet_time += FsmnVadGetRetSnippetTime(result);
- FsmnVadFreeResult(result);
- }
- else
- {
- LOG(ERROR) << ("No return data!\n");
+ int32_t sampling_rate_ = -1;
+ funasr::Audio audio(1);
+ if(is_target_file(wav_file.c_str(), "wav")){
+ int32_t sampling_rate_ = -1;
+ if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else if(is_target_file(wav_file.c_str(), "pcm")){
+ if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else{
+ LOG(ERROR)<<"Wrong wav extension";
+ exit(-1);
+ }
+ char* speech_buff = audio.GetSpeechChar();
+ int buff_len = audio.GetSpeechLen()*2;
+
+ int step = 3200;
+ bool is_final = false;
+
+ for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+ if (sample_offset + step >= buff_len - 1) {
+ step = buff_len - sample_offset;
+ is_final = true;
+ } else {
+ is_final = false;
+ }
+ gettimeofday(&start, NULL);
+ FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+ if (result)
+ {
+ vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
+ print_segs(vad_segments, wav_id);
+ snippet_time += FsmnVadGetRetSnippetTime(result);
+ FsmnVadFreeResult(result);
+ }
+ else
+ {
+ LOG(ERROR) << ("No return data!\n");
+ }
}
}
-
+
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+ FsmnVadUninit(online_hanlde);
FsmnVadUninit(vad_hanlde);
return 0;
}
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index 1eabd3e..d2100a4 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -33,8 +33,9 @@
class Audio {
private:
- float *speech_data;
- int16_t *speech_buff;
+ float *speech_data=nullptr;
+ int16_t *speech_buff=nullptr;
+ char* speech_char=nullptr;
int speech_len;
int speech_align_len;
int offset;
@@ -47,18 +48,22 @@
Audio(int data_type, int size);
~Audio();
void Disp();
- bool LoadWav(const char* filename, int32_t* sampling_rate);
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
+ bool LoadWav(const char* filename, int32_t* sampling_rate);
+ bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
+ bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
int FetchChunck(float *&dout, int len);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
void Split(OfflineStream* offline_streamj);
- void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
+ void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
float GetTimeLen();
int GetQueueSize() { return (int)frame_queue.size(); }
+ char* GetSpeechChar(){return speech_char;}
+ int GetSpeechLen(){return speech_len;}
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index 5cfdb47..af430f7 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,12 +46,6 @@
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
-typedef enum
-{
- FSMN_VAD_OFFLINE=0,
- FSMN_VAD_ONLINE = 1,
-}FSMN_VAD_MODE;
-
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// ASR
@@ -68,11 +62,12 @@
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
// VAD
-_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE);
+_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
// buffer
-_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
// file, support wav & pcm
-_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI std::vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result);
diff --git a/funasr/runtime/onnxruntime/include/vad-model.h b/funasr/runtime/onnxruntime/include/vad-model.h
index b1b1e9d..07f1833 100644
--- a/funasr/runtime/onnxruntime/include/vad-model.h
+++ b/funasr/runtime/onnxruntime/include/vad-model.h
@@ -12,14 +12,9 @@
virtual ~VadModel(){};
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
- virtual void ReadModel(const char* vad_model)=0;
- virtual void LoadConfigFromYaml(const char* filename)=0;
- virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
- std::vector<float> &waves)=0;
- virtual void LoadCmvn(const char *filename)=0;
- virtual void InitCache()=0;
};
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode);
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
+VadModel *CreateVadModel(void* fsmnvad_handle);
} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index 341a16a..d083d8e 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,11 +1,8 @@
file(GLOB files1 "*.cpp")
-file(GLOB files2 "*.cc")
+set(files ${files1})
-set(files ${files1} ${files2})
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
-add_library(funasr ${files})
+add_library(funasr SHARED ${files})
if(WIN32)
set(EXTRA_LIBS pthread yaml-cpp csrc glog)
@@ -24,13 +21,3 @@
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
-
-add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
-add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
-add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
-add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
-target_link_libraries(funasr-onnx-offline PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
-
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 6d63d67..23d0010 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -176,12 +176,12 @@
{
if (speech_buff != NULL) {
free(speech_buff);
-
}
-
if (speech_data != NULL) {
-
free(speech_data);
+ }
+ if (speech_char != NULL) {
+ free(speech_char);
}
}
@@ -296,8 +296,47 @@
return false;
}
-bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate)
{
+ WaveHeader header;
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+ offset = 0;
+ std::ifstream is(filename, std::ifstream::binary);
+ is.read(reinterpret_cast<char *>(&header), sizeof(header));
+ if(!is){
+ LOG(ERROR) << "Failed to read " << filename;
+ return false;
+ }
+ if (!header.Validate()) {
+ return false;
+ }
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
+ if (!header.Validate()) {
+ return false;
+ }
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
+
+ *sampling_rate = header.sample_rate;
+ // header.subchunk2_size contains the number of bytes in the data.
+ // As we assume each sample contains two bytes, so it is divided by 2 here
+ speech_len = header.subchunk2_size / 2;
+ speech_char = (char *)malloc(header.subchunk2_size);
+ memset(speech_char, 0, header.subchunk2_size);
+ is.read(speech_char, header.subchunk2_size);
+
+ return true;
+}
+
+bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+{
WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
@@ -441,6 +480,33 @@
}
+bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate)
+{
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+ offset = 0;
+
+ FILE* fp;
+ fp = fopen(filename, "rb");
+ if (fp == nullptr)
+ {
+ LOG(ERROR) << "Failed to read " << filename;
+ return false;
+ }
+ fseek(fp, 0, SEEK_END);
+ uint32_t n_file_len = ftell(fp);
+ fseek(fp, 0, SEEK_SET);
+
+ speech_len = (n_file_len) / 2;
+ speech_char = (char *)malloc(n_file_len);
+ memset(speech_char, 0, n_file_len);
+ fread(speech_char, sizeof(int16_t), n_file_len/2, fp);
+ fclose(fp);
+
+ return true;
+}
+
int Audio::FetchChunck(float *&dout, int len)
{
if (offset >= speech_align_len) {
@@ -541,7 +607,7 @@
}
-void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
+void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
{
AudioFrame *frame;
@@ -552,7 +618,7 @@
frame = NULL;
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
- vad_segments = vad_obj->Infer(pcm_data);
+ vad_segments = vad_obj->Infer(pcm_data, input_finished);
}
} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 30a1810..58eec25 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -103,9 +103,10 @@
vector<string> WordWithPunc;
for (int i = 0; i < InputStr.size(); i++)
{
- if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// 锟叫硷拷锟接拷模锟�
+ // if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// old rule: space between adjacent ASCII (English) tokens
+ if (i > 0 && !(InputStr[i-1][0] & 0x80) && !(InputStr[i][0] & 0x80))
{
- InputStr[i] = InputStr[i]+ " ";
+ InputStr[i] = " " + InputStr[i];
}
WordWithPunc.push_back(InputStr[i]);
@@ -187,4 +188,4 @@
return punction;
}
-} // namespace funasr
\ No newline at end of file
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
new file mode 100644
index 0000000..0346916
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
@@ -0,0 +1,198 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include <fstream>
+#include "precomp.h"
+
+namespace funasr {
+
+void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts_);
+ // cache merge
+ waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+ int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+ // Send the audio after the last frame shift position to the cache
+ input_cache_.clear();
+ input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+ if (frame_number == 0) {
+ return;
+ }
+ // Delete audio that haven't undergone fbank processing
+ waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+ std::vector<float> buf(waves.size());
+ for (int32_t i = 0; i != waves.size(); ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+ // fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+ int32_t frames = fbank.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank.GetFrame(i);
+ vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+ vad_feats.emplace_back(frame_vector);
+ }
+}
+
+void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &vad_feats,
+ vector<float> &waves, bool input_finished) {
+ FbankKaldi(sample_rate, vad_feats, waves);
+ // cache deal & online lfr,cmvn
+ if (vad_feats.size() > 0) {
+ if (!reserve_waveforms_.empty()) {
+ waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+ }
+ if (lfr_splice_cache_.empty()) {
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ lfr_splice_cache_.emplace_back(vad_feats[0]);
+ }
+ }
+ if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+ vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+ int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+ int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+ int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
+ int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+ waves.begin() + frame_from_waves * frame_shift_sample_length_);
+ int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+ waves.erase(waves.begin() + sample_length, waves.end());
+ } else {
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+ lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
+ }
+ } else {
+ if (input_finished) {
+ if (!reserve_waveforms_.empty()) {
+ waves = reserve_waveforms_;
+ }
+ vad_feats = lfr_splice_cache_;
+ OnlineLfrCmvn(vad_feats, input_finished);
+ }
+ }
+ if(input_finished){
+ Reset();
+ ResetCache();
+ }
+}
+
+int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
+ vector<vector<float>> out_feats;
+ int T = vad_feats.size();
+ int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
+ int lfr_splice_frame_idxs = T_lrf;
+ vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ if (input_finished) {
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ } else {
+ lfr_splice_frame_idxs = i;
+ break;
+ }
+ }
+ }
+ lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+ lfr_splice_cache_.clear();
+ lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
+
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+ }
+ }
+ vad_feats = out_feats;
+ return lfr_splice_frame_idxs;
+}
+
+std::vector<std::vector<int>>
+FsmnVadOnline::Infer(std::vector<float> &waves, bool input_finished) {
+ std::vector<std::vector<float>> vad_feats;
+ std::vector<std::vector<float>> vad_probs;
+ ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished);
+ fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
+
+ std::vector<std::vector<int>> vad_segments;
+ vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_,
+ vad_speech_noise_thres_, vad_sample_rate_);
+ return vad_segments;
+}
+
+void FsmnVadOnline::InitCache(){
+ std::vector<float> cache_feats(128 * 19 * 1, 0);
+ for (int i=0;i<4;i++){
+ in_cache_.emplace_back(cache_feats);
+ }
+}
+
+void FsmnVadOnline::Reset(){
+ in_cache_.clear();
+ InitCache();
+}
+
+void FsmnVadOnline::Test() {
+}
+
+void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+ Ort::Env &env,
+ std::vector<const char *> &vad_in_names,
+ std::vector<const char *> &vad_out_names,
+ knf::FbankOptions &fbank_opts,
+ std::vector<float> &means_list,
+ std::vector<float> &vars_list,
+ int vad_sample_rate,
+ int vad_silence_duration,
+ int vad_max_len,
+ double vad_speech_noise_thres) {
+ vad_session_ = vad_session;
+ vad_in_names_ = vad_in_names;
+ vad_out_names_ = vad_out_names;
+ fbank_opts_ = fbank_opts;
+ means_list_ = means_list;
+ vars_list_ = vars_list;
+ vad_sample_rate_ = vad_sample_rate;
+ vad_silence_duration_ = vad_silence_duration;
+ vad_max_len_ = vad_max_len;
+ vad_speech_noise_thres_ = vad_speech_noise_thres;
+}
+
+FsmnVadOnline::~FsmnVadOnline() {
+}
+
+FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle):fsmnvad_handle_(std::move(fsmnvad_handle)),session_options_{}{
+ InitCache();
+ InitOnline(fsmnvad_handle_->vad_session_,
+ fsmnvad_handle_->env_,
+ fsmnvad_handle_->vad_in_names_,
+ fsmnvad_handle_->vad_out_names_,
+ fsmnvad_handle_->fbank_opts_,
+ fsmnvad_handle_->means_list_,
+ fsmnvad_handle_->vars_list_,
+ fsmnvad_handle_->vad_sample_rate_,
+ fsmnvad_handle_->vad_silence_duration_,
+ fsmnvad_handle_->vad_max_len_,
+ fsmnvad_handle_->vad_speech_noise_thres_);
+}
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.h b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
new file mode 100644
index 0000000..4d429b6
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#pragma once
+#include "precomp.h"
+
+namespace funasr {
+class FsmnVadOnline : public VadModel {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ * https://arxiv.org/abs/1803.05030
+*/
+
+public:
+ explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
+ ~FsmnVadOnline();
+ void Test();
+ std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
+ void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
+ void Reset();
+
+private:
+ E2EVadModel vad_scorer = E2EVadModel();
+ // std::unique_ptr<FsmnVad> fsmnvad_handle_;
+ FsmnVad* fsmnvad_handle_ = nullptr;
+
+ void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ std::vector<float> &waves);
+ int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
+ void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
+ void InitCache();
+ void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+ Ort::Env &env,
+ std::vector<const char *> &vad_in_names,
+ std::vector<const char *> &vad_out_names,
+ knf::FbankOptions &fbank_opts,
+ std::vector<float> &means_list,
+ std::vector<float> &vars_list,
+ int vad_sample_rate,
+ int vad_silence_duration,
+ int vad_max_len,
+ double vad_speech_noise_thres);
+
+ static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+ int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
+ if (frame_num >= 1 && sample_length >= frame_sample_length)
+ return frame_num;
+ else
+ return 0;
+ }
+ void ResetCache() {
+ reserve_waveforms_.clear();
+ input_cache_.clear();
+ lfr_splice_cache_.clear();
+ }
+
+ // from fsmnvad_handle_
+ std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+ std::vector<const char *> vad_in_names_;
+ std::vector<const char *> vad_out_names_;
+ knf::FbankOptions fbank_opts_;
+ std::vector<float> means_list_;
+ std::vector<float> vars_list_;
+
+ std::vector<std::vector<float>> in_cache_;
+ // The reserved waveforms by fbank
+ std::vector<float> reserve_waveforms_;
+ // waveforms reserved after last shift position
+ std::vector<float> input_cache_;
+ // lfr reserved cache
+ std::vector<std::vector<float>> lfr_splice_cache_;
+
+ int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+ int vad_silence_duration_ = VAD_SILENCE_DURATION;
+ int vad_max_len_ = VAD_MAX_LEN;
+ double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+ int lfr_m = VAD_LFR_M;
+ int lfr_n = VAD_LFR_N;
+ int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;
+ int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
+};
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 516dc88..697828b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -37,14 +37,14 @@
this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
- fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
- fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
- fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
- fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
- fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
- fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
- fbank_opts.energy_floor = 0;
- fbank_opts.mel_opts.debug_mel = false;
+ fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+ fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+ fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+ fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+ fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+ fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
}catch(exception const &e){
LOG(ERROR) << "Error when load argument from vad config YAML.";
exit(-1);
@@ -55,6 +55,7 @@
try {
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model, session_options_);
+ LOG(INFO) << "Successfully load model from " << vad_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
@@ -109,7 +110,9 @@
void FsmnVad::Forward(
const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob) {
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -132,9 +135,9 @@
// 4 caches
// cache node {batch,128,19,1}
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- for (int i = 0; i < in_cache_.size(); i++) {
+ for (int i = 0; i < in_cache->size(); i++) {
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+ memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
}
// 4. Onnx infer
@@ -162,15 +165,17 @@
}
// get 4 caches outputs,each size is 128*19
- // for (int i = 1; i < 5; i++) {
- // float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
- // memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
- // }
+ if(!is_final){
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+ }
+ }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
std::vector<float> &waves) {
- knf::OnlineFbank fbank(fbank_opts);
+ knf::OnlineFbank fbank(fbank_opts_);
std::vector<float> buf(waves.size());
for (int32_t i = 0; i != waves.size(); ++i) {
@@ -180,7 +185,7 @@
int32_t frames = fbank.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank.GetFrame(i);
- std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+ std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
vad_feats.emplace_back(frame_vector);
}
}
@@ -205,7 +210,7 @@
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ means_list_.push_back(stof(means_lines[j]));
}
continue;
}
@@ -216,8 +221,8 @@
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
+ // vars_list_.push_back(stof(vars_lines[j])*scale);
+ vars_list_.push_back(stof(vars_lines[j]));
}
continue;
}
@@ -263,8 +268,8 @@
}
// Apply cmvn
for (auto &out_feat: out_feats) {
- for (int j = 0; j < means_list.size(); j++) {
- out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
vad_feats = out_feats;
@@ -276,7 +281,7 @@
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);
LfrCmvn(vad_feats);
- Forward(vad_feats, &vad_probs);
+ Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
E2EVadModel vad_scorer = E2EVadModel();
std::vector<std::vector<int>> vad_segments;
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h
index a8ec4ce..adceb1f 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -22,7 +22,30 @@
void Test();
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
+ void Forward(
+ const std::vector<std::vector<float>> &chunk_feats,
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final);
void Reset();
+
+ std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+ std::vector<const char *> vad_in_names_;
+ std::vector<const char *> vad_out_names_;
+ std::vector<std::vector<float>> in_cache_;
+
+ knf::FbankOptions fbank_opts_;
+ std::vector<float> means_list_;
+ std::vector<float> vars_list_;
+
+ int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+ int vad_silence_duration_ = VAD_SILENCE_DURATION;
+ int vad_max_len_ = VAD_MAX_LEN;
+ double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+ int lfr_m = VAD_LFR_M;
+ int lfr_n = VAD_LFR_N;
private:
@@ -37,31 +60,9 @@
std::vector<float> &waves);
void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
-
- void Forward(
- const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob);
-
void LoadCmvn(const char *filename);
void InitCache();
- std::shared_ptr<Ort::Session> vad_session_ = nullptr;
- Ort::Env env_;
- Ort::SessionOptions session_options_;
- std::vector<const char *> vad_in_names_;
- std::vector<const char *> vad_out_names_;
- std::vector<std::vector<float>> in_cache_;
-
- knf::FbankOptions fbank_opts;
- std::vector<float> means_list;
- std::vector<float> vars_list;
-
- int vad_sample_rate_ = MODEL_SAMPLE_RATE;
- int vad_silence_duration_ = VAD_SILENCE_DURATION;
- int vad_max_len_ = VAD_MAX_LEN;
- double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
- int lfr_m = VAD_LFR_M;
- int lfr_n = VAD_LFR_N;
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index adef504..f504b39 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -11,9 +11,15 @@
return mm;
}
- _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode)
+ _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
{
- funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode);
+ funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle)
+ {
+ funasr::VadModel* mm = funasr::CreateVadModel(fsmnvad_handle);
return mm;
}
@@ -96,7 +102,7 @@
}
// APIs for VAD Infer
- _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
{
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
if (!vad_obj)
@@ -110,13 +116,13 @@
p_result->snippet_time = audio.GetTimeLen();
vector<std::vector<int>> vad_segments;
- audio.Split(vad_obj, vad_segments);
+ audio.Split(vad_obj, vad_segments, input_finished);
p_result->segments = new vector<std::vector<int>>(vad_segments);
return p_result;
}
- _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate)
{
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
if (!vad_obj)
@@ -139,7 +145,7 @@
p_result->snippet_time = audio.GetTimeLen();
vector<std::vector<int>> vad_segments;
- audio.Split(vad_obj, vad_segments);
+ audio.Split(vad_obj, vad_segments, true);
p_result->segments = new vector<std::vector<int>>(vad_segments);
return p_result;
diff --git a/funasr/runtime/onnxruntime/src/online-feature.cpp b/funasr/runtime/onnxruntime/src/online-feature.cpp
deleted file mode 100644
index a21589c..0000000
--- a/funasr/runtime/onnxruntime/src/online-feature.cpp
+++ /dev/null
@@ -1,137 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-
-#include "online-feature.h"
-#include <utility>
-
-namespace funasr {
-OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
- std::vector<std::vector<float>> cmvns)
- : sample_rate_(sample_rate),
- fbank_opts_(std::move(fbank_opts)),
- lfr_m_(lfr_m),
- lfr_n_(lfr_n),
- cmvns_(std::move(cmvns)) {
- frame_sample_length_ = sample_rate_ / 1000 * 25;;
- frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
-}
-
-void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
- vector<float> waves, bool input_finished) {
- input_finished_ = input_finished;
- OnlineFbank(vad_feats, waves);
- // cache deal & online lfr,cmvn
- if (vad_feats.size() > 0) {
- if (!reserve_waveforms_.empty()) {
- waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
- }
- if (lfr_splice_cache_.empty()) {
- for (int i = 0; i < (lfr_m_ - 1) / 2; i++) {
- lfr_splice_cache_.emplace_back(vad_feats[0]);
- }
- }
- if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) {
- vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
- int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
- int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0;
- int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);
- int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
- waves.begin() + frame_from_waves * frame_shift_sample_length_);
- int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
- waves.erase(waves.begin() + sample_length, waves.end());
- } else {
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
- lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
- }
-
- } else {
- if (input_finished_) {
- if (!reserve_waveforms_.empty()) {
- waves = reserve_waveforms_;
- }
- vad_feats = lfr_splice_cache_;
- OnlineLfrCmvn(vad_feats);
- ResetCache();
- }
- }
-
-}
-
-int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
- vector<vector<float>> out_feats;
- int T = vad_feats.size();
- int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / lfr_n_);
- int lfr_splice_frame_idxs = T_lrf;
- vector<float> p;
- for (int i = 0; i < T_lrf; i++) {
- if (lfr_m_ <= T - i * lfr_n_) {
- for (int j = 0; j < lfr_m_; j++) {
- p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
- }
- out_feats.emplace_back(p);
- p.clear();
- } else {
- if (input_finished_) {
- int num_padding = lfr_m_ - (T - i * lfr_n_);
- for (int j = 0; j < (vad_feats.size() - i * lfr_n_); j++) {
- p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
- }
- for (int j = 0; j < num_padding; j++) {
- p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
- }
- out_feats.emplace_back(p);
- } else {
- lfr_splice_frame_idxs = i;
- break;
- }
- }
- }
- lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_);
- lfr_splice_cache_.clear();
- lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
-
- // Apply cmvn
- for (auto &out_feat: out_feats) {
- for (int j = 0; j < cmvns_[0].size(); j++) {
- out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
- }
- }
- vad_feats = out_feats;
- return lfr_splice_frame_idxs;
-}
-
-void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
- vector<float> &waves) {
-
- knf::OnlineFbank fbank(fbank_opts_);
- // cache merge
- waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
- int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
- // Send the audio after the last frame shift position to the cache
- input_cache_.clear();
- input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
- if (frame_number == 0) {
- return;
- }
- // Delete audio that haven't undergone fbank processing
- waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
-
- fbank.AcceptWaveform(sample_rate_, &waves[0], waves.size());
- int32_t frames = fbank.NumFramesReady();
- for (int32_t i = 0; i != frames; ++i) {
- const float *frame = fbank.GetFrame(i);
- vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
- vad_feats.emplace_back(frame_vector);
- }
-
-}
-
-} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
deleted file mode 100644
index 16e6e4b..0000000
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-#pragma once
-#include <vector>
-#include "precomp.h"
-
-using namespace std;
-namespace funasr {
-class OnlineFeature {
-
-public:
- OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
- std::vector<std::vector<float>> cmvns_);
-
- void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
-
-private:
- void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
- int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-
- static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
- int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
- if (frame_num >= 1 && sample_length >= frame_sample_length)
- return frame_num;
- else
- return 0;
- }
-
- void ResetCache() {
- reserve_waveforms_.clear();
- input_cache_.clear();
- lfr_splice_cache_.clear();
- input_finished_ = false;
-
- }
-
- knf::FbankOptions fbank_opts_;
- // The reserved waveforms by fbank
- std::vector<float> reserve_waveforms_;
- // waveforms reserved after last shift position
- std::vector<float> input_cache_;
- // lfr reserved cache
- std::vector<std::vector<float>> lfr_splice_cache_;
- std::vector<std::vector<float>> cmvns_;
-
- int sample_rate_ = 16000;
- int frame_sample_length_ = sample_rate_ / 1000 * 25;;
- int frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
- int lfr_m_;
- int lfr_n_;
- bool input_finished_ = false;
-
-};
-
-} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 533c16f..9df0977 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -18,7 +18,7 @@
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
- Vocab* vocab;
+ Vocab* vocab = nullptr;
vector<float> means_list;
vector<float> vars_list;
const float scale = 22.6274169979695;
@@ -30,7 +30,7 @@
void ApplyCmvn(vector<float> *v);
string GreedySearch( float* in, int n_len, int64_t token_nums);
- std::shared_ptr<Ort::Session> m_session;
+ std::shared_ptr<Ort::Session> m_session = nullptr;
Ort::Env env_;
Ort::SessionOptions session_options;
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index e607dbf..838dddc 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -36,8 +36,9 @@
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
-#include "fsmn-vad.h"
#include "e2e-vad.h"
+#include "fsmn-vad.h"
+#include "fsmn-vad-online.h"
#include "vocab.h"
#include "audio.h"
#include "tensor.h"
diff --git a/funasr/runtime/onnxruntime/src/vad-model.cpp b/funasr/runtime/onnxruntime/src/vad-model.cpp
index 336758f..c164c3e 100644
--- a/funasr/runtime/onnxruntime/src/vad-model.cpp
+++ b/funasr/runtime/onnxruntime/src/vad-model.cpp
@@ -1,14 +1,10 @@
#include "precomp.h"
namespace funasr {
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
{
VadModel *mm;
- if(mode == FSMN_VAD_OFFLINE){
- mm = new FsmnVad();
- }else{
- LOG(ERROR)<<"Online fsmn vad not imp!";
- }
+ mm = new FsmnVad();
string vad_model_path;
string vad_cmvn_path;
@@ -25,4 +21,11 @@
return mm;
}
+VadModel *CreateVadModel(void* fsmnvad_handle)
+{
+ VadModel *mm;
+ mm = new FsmnVadOnline((FsmnVad*)fsmnvad_handle);
+ return mm;
+}
+
} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
index 6460fbf..3810cd6 100644
--- a/funasr/runtime/python/websocket/wss_srv_asr.py
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -58,16 +58,36 @@
model=args.asr_model_online,
ngpu=args.ngpu,
ncpu=args.ncpu,
- model_revision='v1.0.6',
+ model_revision='v1.0.4',
+ update_model='v1.0.4',
mode='paraformer_streaming')
-print("model loaded")
+print("model loaded! only support one client at the same time now!!!!")
+async def ws_reset(websocket):
+ print("ws reset now, total num is ",len(websocket_users))
+ websocket.param_dict_asr_online = {"cache": dict()}
+ websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
+ websocket.param_dict_asr_online["is_final"]=True
+ audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
+ inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+ inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
+ await websocket.close()
+
+
+async def clear_websocket():
+ for websocket in websocket_users:
+ await ws_reset(websocket)
+ websocket_users.clear()
+
+
+
async def ws_serve(websocket, path):
frames = []
frames_asr = []
frames_asr_online = []
global websocket_users
+ await clear_websocket()
websocket_users.add(websocket)
websocket.param_dict_asr = {}
websocket.param_dict_asr_online = {"cache": dict()}
@@ -75,7 +95,7 @@
websocket.param_dict_punc = {'cache': list()}
websocket.vad_pre_idx = 0
speech_start = False
- speech_end_i = False
+ speech_end_i = -1
websocket.wav_name = "microphone"
websocket.mode = "2pass"
print("new user connected", flush=True)
@@ -104,7 +124,7 @@
# asr online
frames_asr_online.append(message)
- websocket.param_dict_asr_online["is_final"] = speech_end_i
+ websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
@@ -114,14 +134,14 @@
frames_asr.append(message)
# vad online
speech_start_i, speech_end_i = await async_vad(websocket, message)
- if speech_start_i:
+ if speech_start_i != -1:
speech_start = True
beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
- if speech_end_i or not websocket.is_speaking:
+ if speech_end_i != -1 or not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
@@ -139,7 +159,8 @@
except websockets.ConnectionClosed:
- print("ConnectionClosed...", websocket_users)
+ print("ConnectionClosed...", websocket_users,flush=True)
+ await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("InvalidState...")
@@ -151,15 +172,15 @@
segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
- speech_start = False
- speech_end = False
+ speech_start = -1
+ speech_end = -1
if len(segments_result) == 0 or len(segments_result["text"]) > 1:
return speech_start, speech_end
if segments_result["text"][0][0] != -1:
speech_start = segments_result["text"][0][0]
if segments_result["text"][0][1] != -1:
- speech_end = True
+ speech_end = segments_result["text"][0][1]
return speech_start, speech_end
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index 8217b30..58ca972 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -2,6 +2,8 @@
project(FunASRWebscoket)
+set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
diff --git a/funasr/version.txt b/funasr/version.txt
index d3532a1..659914a 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.5.7
+0.5.8
--
Gitblit v1.9.1