nichongjia-2007
2023-06-07 c7fc6149b3c5c2de3107c4f1d4983309882d1a1a
Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
39个文件已修改
2个文件已删除
13个文件已添加
1 文件已复制
4 文件已重命名
1672 ■■■■ 已修改文件
MODEL_LICENSE 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/model_zoo/modelscope_models.md 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/rnnt/README.md 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/rnnt/run.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/README.md 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml 98 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/local/data_prep.sh 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/local/download_and_untar.sh 97 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/local/spm_encode.py 98 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/local/spm_train.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/path.sh 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/run.sh 222 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech/rnnt/utils 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/librispeech_100h/rnnt/README.md 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/TEMPLATE/README.md 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/tp/TEMPLATE/README.md 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/build_dataloader.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/html5/static/main.js 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/html5/static/wsconnecter.js 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/CMakeLists.txt 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/CMakeLists.txt 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp 89 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/audio.h 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/funasrruntime.h 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/vad-model.h 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/CMakeLists.txt 17 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/audio.cpp 78 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer.cpp 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp 198 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad-online.h 88 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad.cpp 51 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad.h 45 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasrruntime.cpp 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/online-feature.cpp 137 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/online-feature.h 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vad-model.cpp 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/wss_srv_asr.py 41 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/CMakeLists.txt 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/version.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
MODEL_LICENSE
@@ -9,7 +9,7 @@
为了保证更好的社区合作,我们制定了以下协议,希望您仔细阅读并遵守本协议。
1 定义
本协议中,[FunASR 软件]指 FunASR 开源模型及其衍生品,包括 Finetune 后的模型;[您]指使用、修改、分享和学习[FunASR 软件]的个人或组织。
本协议中,[FunASR 软件]指 FunASR 开源模型权重及其衍生品,包括 Finetune 后的模型;[您]指使用、修改、分享和学习[FunASR 软件]的个人或组织。
2 许可和限制
2.1 许可
@@ -18,7 +18,7 @@
2.2 限制
您在使用、复制、修改和分享[FunASR 软件]时,必须注明出处以及作者信息。并且,将[FunASR 软件]上传至其他第三方平台以供下载,需要获得额外许可,可通过官方邮件(funasr@list.alibaba-inc.com)进行申请(免费)。
您在使用、复制、修改和分享[FunASR 软件]时,必须注明出处以及作者信息,并保留[FunASR 软件]中相关模型名称。
3 责任和风险承担
[FunASR 软件]仅作为参考和学习使用,不对您使用或修改[FunASR 软件]造成的任何直接或间接损失承担任何责任。您对[FunASR 软件]的使用和修改应该自行承担风险。
@@ -54,7 +54,7 @@
You are free to use, copy, modify, and share [FunASR software] under the conditions of this agreement.
2.2 Restrictions
You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. To upload the [FunASR software] to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com).
You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. You should keep the relevant names of models in [FunASR software].
3 Responsibility and Risk
[FunASR software] is for reference and learning purposes only and is not responsible for any direct or indirect losses caused by your use or modification of [FunASR software]. You should take responsibility and risks for your use and modification of [FunASR software].
README.md
@@ -13,7 +13,6 @@
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial_CN**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
docs/model_zoo/modelscope_models.md
@@ -1,8 +1,10 @@
# Pretrained Models on ModelScope
## Model License
You are free to use, copy, modify, and share FunASR models under the conditions of this agreement. You should indicate the model source and author information when using, copying, modifying and sharing FunASR models. To upload the FunASR models to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com). Full model license could see [license](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE)
You are free to use, copy, modify, and share FunASR models under the conditions of this agreement. You should indicate the model source and author information when using, copying, modifying and sharing FunASR models. You should keep the relevant names of models in [FunASR software]. The full model license can be found at [license](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE).
## Model Usage
Ref to [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
## Model Zoo
Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
egs/aishell/rnnt/README.md
@@ -5,14 +5,14 @@
- 8 gpu(Tesla V100)
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train config: conf/train_conformer_rnnt_unified.yaml
- chunk config: chunk size 16, full left chunk
- chunk config: chunk size 16, 1 left chunk
- LM config: LM was not used
- Model size: 90M
## Results (CER)
- Decode config: conf/train_conformer_rnnt_unified.yaml
- Decode config: conf/decode_rnnt_conformer_streaming.yaml
|   testset   | CER(%)  |
|   testset   |  CER(%) |
|:-----------:|:-------:|
|     dev     |  5.53   |
|    test     |  6.24   |
|     dev     |  5.43   |
|    test     |  6.04   |
egs/aishell/rnnt/run.sh
@@ -4,7 +4,7 @@
# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
gpu_num=2
gpu_num=8
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
egs/librispeech/rnnt/README.md
New file
@@ -0,0 +1,18 @@
# Streaming RNN-T Result
## Training Config
- 8 gpu(Tesla V100)
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train config: conf/train_conformer_rnnt_unified.yaml
- chunk config: chunk size 16, 1 left chunk
- LM config: LM was not used
- Model size: 90M
## Results (WER)
- Decode config: conf/decode_rnnt_conformer_streaming.yaml
|      testset   |  WER(%) |
|:--------------:|:-------:|
|    test_clean  |   3.58  |
|    test_other  |   9.27  |
egs/librispeech/rnnt/conf/decode_rnnt_conformer_streaming.yaml
New file
@@ -0,0 +1,8 @@
# The conformer transducer decoding configuration from @jeon30c
# Beam size used when decoding.
beam_size: 10
# Run real streaming (chunk-by-chunk) decoding; simulated streaming is off.
simu_streaming: false
streaming: true
# Chunk and context sizes for streaming decoding.
# NOTE(review): units are not visible here (frames vs. chunks). chunk_size 16
# matches the training `default_chunk_size: 16`; left_context 16 presumably
# corresponds to the training `left_chunk_size: 1` (one chunk) — confirm.
chunk_size: 16
left_context: 16
right_context: 0
egs/librispeech/rnnt/conf/train_conformer_rnnt_unified.yaml
New file
@@ -0,0 +1,98 @@
# encoder related (chunk-based conformer encoder)
encoder: chunk_conformer
encoder_conf:
      activation_type: swish
      time_reduction_factor: 2
      # Chunk settings for unified (streaming + full-context) training; the
      # recipe README summarises them as "chunk size 16, 1 left chunk".
      unified_model_training: true
      default_chunk_size: 16
      jitter_range: 4
      left_chunk_size: 1
      embed_vgg_like: false
      subsampling_factor: 4
      linear_units: 2048
      output_size: 512
      attention_heads: 8
      dropout_rate: 0.5
      positional_dropout_rate: 0.5
      attention_dropout_rate: 0.5
      cnn_module_kernel: 15
      num_blocks: 12
# decoder related
rnnt_decoder: rnnt
rnnt_decoder_conf:
    embed_size: 512
    hidden_size: 512
    embed_dropout_rate: 0.5
    dropout_rate: 0.5
    use_embed_mask: true
joint_network_conf:
    joint_space_size: 512
# frontend related: 80-dim filterbank features at 16 kHz, 25 ms window / 10 ms shift
frontend: wav_frontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
# Auxiliary CTC
# NOTE(review): weight 0.0 presumably means the CTC branch does not contribute
# to the training loss — confirm against the rnnt_unified model.
model: rnnt_unified
model_conf:
    auxiliary_ctc_weight: 0.0
# minibatch related
# (the batch size itself is set under dataset_conf.batch_conf below)
use_amp: true
# optimization related
accum_grad: 4
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
    - valid
    - loss
# Checkpoints are ranked by the validation metric `cer_transducer_chunk` (lower
# is better); the 10 best are kept. run.sh decodes with
# valid.cer_transducer_chunk.ave_10best.pb.
best_model_criterion:
-   - valid
    - cer_transducer_chunk
    - min
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 25000
# SpecAugment: time warping plus frequency and time masking
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 40
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 50
    num_time_mask: 5
# data loading: shuffling and token-based batching
dataset_conf:
    shuffle: True
    shuffle_conf:
        shuffle_size: 1024
        sort_size: 500
    batch_conf:
        batch_type: token
        batch_size: 10000
    num_workers: 8
log_interval: 50
# NOTE(review): YAML parses `None` here as the string "None", not null;
# presumably the training argument parser maps it to "no normalization" — confirm.
normalize: None
egs/librispeech/rnnt/local/data_prep.sh
New file
@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Copyright 2014  Vassil Panayotov
#           2014  Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <src-dir> <dst-dir>"
  echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
  exit 1
fi

# Build <dst-dir>/wav.scp ("<utt-id> <flac-path>") and <dst-dir>/text from one
# LibriSpeech part laid out as <src-dir>/<reader>/<chapter>/*.flac, with the
# chapter transcript in <reader>-<chapter>.trans.txt.
src=$1
dst=$2

# all utterances are FLAC compressed
# NOTE(review): this script only writes .flac paths into wav.scp and never runs
# `flac` itself; confirm whether downstream readers still need the binary.
if ! which flac >&/dev/null; then
   echo "Please install 'flac' on ALL worker nodes!"
   exit 1
fi

spk_file=$src/../SPEAKERS.TXT

mkdir -p "$dst" || exit 1

[ ! -d "$src" ] && echo "$0: no such directory $src" && exit 1
[ ! -f "$spk_file" ] && echo "$0: expected file $spk_file to exist" && exit 1

# Start from empty outputs: both files are only appended to below.
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm "$wav_scp"
trans=$dst/text; [[ -f "$trans" ]] && rm "$trans"

for reader_dir in $(find -L "$src" -mindepth 1 -maxdepth 1 -type d | sort); do
  reader=$(basename "$reader_dir")
  # `-eq` against itself only succeeds for integers; hide the test's own
  # "integer expression expected" noise and report our own message instead.
  if ! [ "$reader" -eq "$reader" ] 2>/dev/null; then
    echo "$0: unexpected subdirectory name $reader"
    exit 1
  fi

  for chapter_dir in $(find -L "$reader_dir"/ -mindepth 1 -maxdepth 1 -type d | sort); do
    chapter=$(basename "$chapter_dir")
    if ! [ "$chapter" -eq "$chapter" ] 2>/dev/null; then
      echo "$0: unexpected chapter-subdirectory name $chapter"
      exit 1
    fi

    # One "<utt-id> <path>" entry per utterance, with no trailing whitespace
    # after the path.
    find -L "$chapter_dir"/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
      awk -v "dir=$chapter_dir" '{printf "%s %s/%s.flac\n", $0, dir, $0}' >>"$wav_scp" || exit 1

    chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
    [ ! -f "$chapter_trans" ] && echo "$0: expected file $chapter_trans to exist" && exit 1
    cat "$chapter_trans" >>"$trans"
  done
done

echo "$0: successfully prepared data in $dst"
exit 0
egs/librispeech/rnnt/local/download_and_untar.sh
New file
@@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
# Download <url-base>/<corpus-part>.tar.gz into <data-base> and extract it to
# <data-base>/LibriSpeech/<corpus-part>. A .complete marker is written there on
# success so that re-running the script for that part is a no-op.
remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
  echo "          train-clean-100, train-clean-360, train-other-500."
  exit 1
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1
fi

part_ok=false
list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1
fi

if [ -f "$data/LibriSpeech/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0
fi

# sizes of the archive files in bytes.  This is some older versions.
sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
# sizes_new is the archive file sizes of the final release.  Some of these sizes are of
# things we probably won't download.
sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"

archive=$data/$part.tar.gz

# A pre-existing archive is trusted only if its byte size matches a known
# release; anything else (e.g. a truncated download) is deleted and re-fetched.
if [ -f "$archive" ]; then
  size=$(/bin/ls -l "$archive" | awk '{print $5}')
  size_ok=false
  for s in $sizes_old $sizes_new; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $archive because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$archive"
  else
    echo "$archive exists and appears to be complete."
  fi
fi

if [ ! -f "$archive" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1
  fi
  full_url=$url/$part.tar.gz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."

  # NOTE(review): --no-check-certificate disables TLS verification for https
  # URLs; the archive is only checked by size above, not by checksum.
  if ! wget -P "$data" --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1
  fi
fi

if ! tar -C "$data" -xvzf "$archive"; then
  echo "$0: error un-tarring archive $archive"
  exit 1
fi

touch "$data/LibriSpeech/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $archive"

if $remove_archive; then
  echo "$0: removing $archive file since --remove-archive option was supplied."
  rm "$archive"
fi
egs/librispeech/rnnt/local/spm_encode.py
New file
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in
# https://github.com/pytorch/fairseq/blob/master/LICENSE
import argparse
import contextlib
import sys
import sentencepiece as spm
def main():
    """Encode text files with a SentencePiece model.

    Every file in ``--inputs`` is read in lock-step with the others ("-" means
    stdin). Each line is stripped and encoded either as pieces or as integer
    ids (``--output_format``), and the space-joined result is written to the
    matching entry of ``--outputs`` ("-" means stdout).

    A row is written only when the corresponding line of *every* input is
    non-empty and its token count passes the optional ``--min-len`` /
    ``--max-len`` bounds; otherwise the whole row is dropped so the outputs
    stay aligned. Progress and skip counts are reported on stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True,
                        help="sentencepiece model to use for encoding")
    parser.add_argument("--inputs", nargs="+", default=['-'],
                        help="input files to filter/encode")
    parser.add_argument("--outputs", nargs="+", default=['-'],
                        help="path to save encoded outputs")
    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
    parser.add_argument("--min-len", type=int, metavar="N",
                        help="drop lines that encode to fewer than N tokens")
    parser.add_argument("--max-len", type=int, metavar="N",
                        help="drop lines that encode to more than N tokens")
    args = parser.parse_args()

    # Use parser.error (usage message, exit status 2) rather than `assert`,
    # which is stripped when Python runs with -O.
    if len(args.inputs) != len(args.outputs):
        parser.error("number of input and output paths should match")

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.output_format == "piece":
        def encode(text):
            return sp.EncodeAsPieces(text)
    elif args.output_format == "id":
        def encode(text):
            return list(map(str, sp.EncodeAsIds(text)))
    else:
        # Unreachable: argparse `choices` already restricts output_format.
        raise NotImplementedError

    if args.min_len is not None or args.max_len is not None:
        def valid(tokens):
            return (
                (args.min_len is None or len(tokens) >= args.min_len) and
                (args.max_len is None or len(tokens) <= args.max_len)
            )
    else:
        def valid(tokens):
            return True

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(path, "r", encoding="utf-8"))
            if path != "-" else sys.stdin
            for path in args.inputs
        ]
        outputs = [
            stack.enter_context(open(path, "w", encoding="utf-8"))
            if path != "-" else sys.stdout
            for path in args.outputs
        ]

        stats = {
            "num_empty": 0,
            "num_filtered": 0,
        }

        def encode_line(line):
            """Return the encoded tokens of ``line``, or None if it is empty or filtered out."""
            line = line.strip()
            if len(line) > 0:
                tokens = encode(line)
                if valid(tokens):
                    return tokens
                stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        for i, lines in enumerate(zip(*inputs), start=1):
            enc_lines = list(map(encode_line, lines))
            # Keep outputs row-aligned: write nothing unless every input line survived.
            if not any(enc_line is None for enc_line in enc_lines):
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(" ".join(enc_line), file=output_h)
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)


if __name__ == "__main__":
    main()
egs/librispeech/rnnt/local/spm_train.py
New file
@@ -0,0 +1,12 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# https://github.com/pytorch/fairseq/blob/master/LICENSE
import sys
import sentencepiece as spm
if __name__ == "__main__":
    # Forward every command-line argument (e.g. --input=..., --vocab_size=...)
    # to the SentencePiece trainer as one space-separated option string.
    trainer_options = " ".join(sys.argv[1:])
    spm.SentencePieceTrainer.Train(trainer_options)
egs/librispeech/rnnt/path.sh
New file
@@ -0,0 +1,5 @@
# Root of the FunASR checkout: three directories above the directory this file
# is sourced from (run.sh sources it as ./path.sh).
export FUNASR_DIR="${PWD}/../../.."
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Prepend funasr/bin so that run.sh can invoke entry points such as train.py
# without a path (presumably they live there — confirm).
export PATH="${FUNASR_DIR}/funasr/bin:${PATH}"
egs/librispeech/rnnt/run.sh
New file
@@ -0,0 +1,222 @@
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
gpu_num=8
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
# NOTE: feats_dir and exp_dir default to empty, which makes every
# "${feats_dir}/data/..." and "${exp_dir}/exp/..." path below resolve under "/".
# Set them before running (presumably via --feats_dir / --exp_dir through
# utils/parse_options.sh).
feats_dir="" # feature output directory
exp_dir=""
lang=en
token_type=bpe
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5

# feature configuration
feats_dim=80
nj=64

# data
raw_data=   # directory holding (stage 0) or receiving (stage -1) the LibriSpeech download
data_url=www.openslr.org/resources/12

# bpe model
nbpe=5000
bpemode=unigram

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline'
# (-x 'print commands' is not enabled here).
set -e
set -u
set -o pipefail

train_set=train_960
valid_set=dev
test_sets="test_clean test_other dev_clean dev_other"

asr_config=conf/train_conformer_rnnt_unified.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"

inference_config=conf/decode_rnnt_conformer_streaming.yaml
inference_asr_model=valid.cer_transducer_chunk.ave_10best.pb

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
    # `$(( ))` replaces the deprecated `$[ ]` arithmetic form.
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    # Fetch and unpack every LibriSpeech part into ${raw_data}/LibriSpeech/<part>.
    librispeech_parts="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
    for part in ${librispeech_parts}; do
        local/download_and_untar.sh ${raw_data} ${data_url} ${part}
    done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # One data directory per LibriSpeech part; dashes in the part name become
    # underscores in the directory name (dev-clean -> data/dev_clean).
    for x in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do
        local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
    done

    # Merge dev_clean + dev_other into the validation set. `mkdir -p` so that a
    # re-run of this stage does not abort under `set -e` on an existing directory.
    mkdir -p $feats_dir/data/$valid_set
    dev_sets="dev_clean dev_other"
    for file in wav.scp text; do
        ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
    done

    # Merge the three training parts into the 960h training set.
    mkdir -p $feats_dir/data/$train_set
    train_sets="train_clean_100 train_clean_360 train_other_500"
    for file in wav.scp text; do
        ( for f in $train_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$train_set/$file || exit 1;
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    # Global CMVN statistics come from the training set only; stages 4 and 5
    # read them back from ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn.
    utils/compute_cmvn.sh \
        --cmd "$train_cmd" \
        --nj $nj \
        --feats_dim ${feats_dim} \
        ${feats_dir}/data/${train_set}
fi
# Token list and BPE model paths: written by stage 2, read by stage 4.
lang_dir=${feats_dir}/data/lang_char
token_list=${lang_dir}/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=${lang_dir}/${train_set}_${bpemode}${nbpe}
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
    echo "stage 2: Dictionary and Json Data Preparation"
    mkdir -p ${lang_dir}/
    # Token list layout: <blank>, <s>, </s>, then the sorted unique BPE pieces
    # seen in the training text, then <unk>.
    {
        echo "<blank>"
        echo "<s>"
        echo "</s>"
    } > ${token_list}
    # Training transcripts without the leading utterance id.
    cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${lang_dir}/input.txt
    local/spm_train.py --input=${lang_dir}/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
    local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${lang_dir}/input.txt \
        | tr ' ' '\n' \
        | sort \
        | uniq >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# Stage 3: language-model training. This recipe trains no LM, so the stage only
# announces itself.
world_size=${gpu_num}  # single machine: one training process per GPU
if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    # Shared init file for --dist_init_method; remove any stale copy left by an earlier run.
    INIT_FILE=./ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"

    # Launch one background training process per GPU; rank i runs on the
    # (i+1)-th id listed in CUDA_VISIBLE_DEVICES and logs to train.log.<i>.
    train_pids=""
    for ((i = 0; i < gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            train.py \
                --task_name asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --split_with_space false \
                --bpemodel ${bpemodel}.model \
                --token_type $token_type \
                --token_list $token_list \
                --data_dir ${feats_dir}/data \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                --speed_perturb ${speed_perturb} \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
        train_pids="${train_pids} $!"
    done

    # A bare `wait` always returns 0, which let a failed rank go unnoticed and the
    # script continue to inference. Wait on each PID and fail the stage instead.
    n_failed=0
    for pid in ${train_pids}; do
        wait "${pid}" || n_failed=$((n_failed + 1))
    done
    if [ ${n_failed} -gt 0 ]; then
        echo "$0: ${n_failed} training process(es) failed; see ${exp_dir}/exp/${model_dir}/log/train.log.*"
        exit 1
    fi
fi
# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "stage 5: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        # Skip an already-decoded set rather than exiting, so the remaining
        # test sets are still decoded.
        if [ -d ${_dir} ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            continue
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        # Never start more jobs than there are utterances.
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}

        # Concatenate the per-job 1-best outputs into one sorted file per kind.
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        # Score against the reference; the last 3 lines of the report are the summary.
        python utils/compute_wer.py ${_data}/text ${_dir}/text ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
egs/librispeech/rnnt/utils
New file
@@ -0,0 +1 @@
../../aishell/transformer/utils
egs/librispeech_100h/rnnt/README.md
@@ -8,9 +8,9 @@
- Model size: 30.54M
## Results (WER)
- Decode config: conf/decode_rnnt_transformer.yaml (ctc weight:0.5)
- Decode config: conf/decode_rnnt_conformer.yaml
|      testset   | WER(%)  |
|      testset   |  WER(%) |
|:--------------:|:-------:|
|    test_clean  |  6.64   |
|    test_other  |  17.12  |
egs_modelscope/asr/TEMPLATE/README.md
@@ -26,7 +26,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model='v1.0.6',
    update_model=False,
    mode='paraformer_streaming'
    )
import soundfile
@@ -54,7 +54,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model='v1.0.6',
    update_model=False,
    mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -5,6 +5,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model=False,
    mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
@@ -15,7 +15,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model='v1.0.6',
    update_model=False,
    mode="paraformer_streaming"
)
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
@@ -15,6 +15,7 @@
    kwargs = dict(
        model=params.model,
        model_revision='v1.0.6',
        update_model=False,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -12,6 +12,7 @@
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        model_revision='v1.0.6',
        update_model=False,
        mode="paraformer_fake_streaming",
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
@@ -54,7 +54,7 @@
            --audio_in ${output_dir}/split/wav.$JOB.scp \
            --output_dir ${output_dir}/output.$JOB \
            --batch_size ${batch_size} \
            --gpuid ${gpuid}
            --gpuid ${gpuid} \
            --mode "paraformer_fake_streaming"
        }&
    done
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -5,6 +5,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model=False,
    mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
@@ -15,7 +15,7 @@
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    update_model='v1.0.6',
    update_model=False,
    mode="paraformer_streaming"
)
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
@@ -15,6 +15,7 @@
    kwargs = dict(
        model=params.model,
        model_revision='v1.0.6',
        update_model=False,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -12,6 +12,7 @@
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        model_revision='v1.0.6',
        update_model=False,
        mode="paraformer_fake_streaming",
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
@@ -54,7 +54,7 @@
            --audio_in ${output_dir}/split/wav.$JOB.scp \
            --output_dir ${output_dir}/output.$JOB \
            --batch_size ${batch_size} \
            --gpuid ${gpuid}
            --gpuid ${gpuid} \
            --mode "paraformer_fake_streaming"
        }&
    done
egs_modelscope/tp/TEMPLATE/README.md
@@ -11,7 +11,7 @@
inference_pipeline = pipeline(
    task=Tasks.speech_timestamp,
    model='damo/speech_timestamp_prediction-v1-16k-offline',
    output_dir=None)
    model_revision='v1.1.0')
rec_result = inference_pipeline(
    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
fun_text_processing/inverse_text_normalization/id/id_unit_test.tsv
@@ -1,10 +1,10 @@
dua ribu dua puluh dua    2022
tiga ribu    300
tiga ribu    3000
sembilan ribu sembilan ratus sembilan puluh sembilan    9999
seribu satu    100001
ribu    100
seribu satu    1001
ribu    1000
seribu    1000
seribu dua ratus delapan puluh sembilan    10289
seribu dua ratus delapan puluh sembilan    1289
ribu dua ratus delapan puluh sembilan    1289
nol satu dua tiga empat lima enam tujuh delapan sembilan    01 2345-6789
empat belas    14
@@ -22,8 +22,8 @@
seratus dua puluh tiga    123
ratus dua puluh tiga    123
dua puluh empat maret     24 maret
ribu tujuh puluh enam    10076
seribu tujuh puluh enam    100076
ribu tujuh puluh enam rupiah    10076 rupiah
ribu tujuh puluh enam    1076
seribu tujuh puluh enam    1076
ribu tujuh puluh enam rupiah    1076 rupiah
tujuh puluh enam    76
ditambah enam dua dua satu enam lima tiga sembilan nol enam nol lima    +62 21 6539-0605
ditambah enam dua dua satu enam lima tiga sembilan nol enam nol lima    +62 21 6539-0605
fun_text_processing/inverse_text_normalization/id/taggers/cardinal.py
@@ -26,11 +26,10 @@
        graph_teen = pynini.string_file(get_abs_path("data/numbers/teen.tsv"))
        graph_hundreds = pynini.string_file(get_abs_path("data/numbers/hundreds.tsv"))
        graph_thousand = pynini.string_file(get_abs_path("data/numbers/thousand.tsv"))
        graph_cents = pynini.cross("seratus", "100") | pynini.cross("ratus", "100") | pynini.union(graph_hundreds, pynutil.insert("0"))
        graph_hundred = pynini.cross("ratus", "") | pynini.cross("seratus", "")
        graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("00"))
        graph_hundred_component = pynini.union(graph_digit + delete_space + graph_hundred, pynutil.insert("0"))
        graph_hundred_component += delete_space
        graph_hundred_component += pynini.union(
            graph_teen | pynutil.insert("00"),
@@ -44,8 +43,8 @@
                (graph_ties | pynutil.insert("0")) + delete_space + (
                            graph_digit | pynutil.insert("0")),
        )
        graph_hundred_component = graph_hundred_component | graph_cents | graph_one_hundred_component
        graph_hundred_component = graph_hundred_component | graph_one_hundred_component
        graph_hundred_component_at_least_one_none_zero_digit = graph_hundred_component @ (
            pynini.closure(DAMO_DIGIT) + (DAMO_DIGIT - "0") + pynini.closure(DAMO_DIGIT)
        )
@@ -54,14 +53,12 @@
        )
        graph_thousand = pynini.cross("ribu", "") | pynini.cross("seribu", "")
        graph_one_thousand_component = pynini.union(pynini.cross("ribu", "1") | pynini.cross("seribu", "1"))
        graph_thousand_cents = pynini.cross("seribu", "10") | pynini.cross("ribu","10") | pynini.union(graph_thousand, pynutil.insert(""))
        graph_thousands = pynini.union(
            graph_hundred_component_at_least_one_none_zero_digit + delete_space + (pynutil.delete("ribu") | pynutil.delete("seribu")),
            pynutil.insert("000", weight=0.1),
        )
        graph_thousand_component = pynini.union(graph_digit + delete_space + graph_thousand, pynutil.insert("000"))
        graph_thousand_component += delete_space
        graph_thousands = graph_thousands | graph_thousand_cents | graph_thousand_component | graph_one_thousand_component
        graph_thousands = graph_thousands | (pynutil.insert("00") + graph_one_thousand_component)
        graph_million = pynini.union(
            graph_hundred_component_at_least_one_none_zero_digit + delete_space + (pynutil.delete("juta") | pynutil.delete("sejuta")),
funasr/datasets/large_datasets/build_dataloader.py
@@ -77,12 +77,13 @@
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
        self.dataset_conf = args.dataset_conf
        self.frontend_conf = args.frontend_conf
        self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None
        logging.info("dataloader config: {}".format(self.dataset_conf))
        batch_mode = self.dataset_conf.get("batch_mode", "padding")
        data_list = args.train_data_file if mode == "train" else args.valid_data_file
        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                               self.dataset_conf, self.frontend_conf,
                               speed_perturb=args.speed_perturb if mode == "train" else None,
                               speed_perturb=self.speed_perturb if mode == "train" else None,
                               mode=mode, batch_mode=batch_mode)
    def build_iter(self, epoch, shuffle=True):
funasr/runtime/html5/static/main.js
@@ -145,7 +145,9 @@
    isRec = false;
    info_div.innerHTML="请等候...";
    btnStop.disabled = true;
    setTimeout(function(){btnStart.disabled = false;info_div.innerHTML="请点击开始";}, 3000 );
    setTimeout(function(){
        console.log("call stop ws!");
        wsconnecter.wsStop();btnStart.disabled = false;info_div.innerHTML="请点击开始";}, 3000 );
    rec.stop(function(blob,duration){
  
        console.log(blob);
funasr/runtime/html5/static/wsconnecter.js
@@ -28,7 +28,11 @@
        if ( 'WebSocket' in window ) {
            speechSokt = new WebSocket( Uri ); // 定义socket连接对象
            speechSokt.onopen = function(e){onOpen(e);}; // 定义响应函数
            speechSokt.onclose = function(e){onClose(e);};
            speechSokt.onclose = function(e){
                console.log("onclose ws!");
                speechSokt.close();
                onClose(e);
                };
            speechSokt.onmessage = function(e){onMessage(e);};
            speechSokt.onerror = function(e){onError(e);};
            return 1;
@@ -42,6 +46,7 @@
    // 定义停止与发送函数
    this.wsStop = function () {
        if(speechSokt != undefined) {
            console.log("stop ws!");
            speechSokt.close();
        }
    };
funasr/runtime/onnxruntime/CMakeLists.txt
@@ -7,6 +7,8 @@
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
@@ -30,12 +32,13 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
add_subdirectory(third_party/yaml-cpp)
add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
add_subdirectory(src)
if(ENABLE_GLOG)
    include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
    set(BUILD_TESTING OFF)
    add_subdirectory(third_party/glog)
endif()
endif()
add_subdirectory(third_party/yaml-cpp)
add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
add_subdirectory(src)
add_subdirectory(bin)
funasr/runtime/onnxruntime/bin/CMakeLists.txt
New file
@@ -0,0 +1,16 @@
# Command-line demo / benchmark executables for the FunASR ONNX runtime.
# Each tool is built from the same-named .cpp file in this directory and
# links against the `funasr` library target defined in ../src.
set(FUNASR_ONNX_TOOLS
    funasr-onnx-offline
    funasr-onnx-offline-vad
    funasr-onnx-online-vad
    funasr-onnx-offline-punc
    funasr-onnx-offline-rtf
)

foreach(tool IN LISTS FUNASR_ONNX_TOOLS)
    add_executable(${tool} "${tool}.cpp")
    # Runtime headers (funasrruntime.h, audio.h, ...). `funasr` does not export
    # this directory as a usage requirement, so each tool adds it itself.
    # PROJECT_SOURCE_DIR (not CMAKE_SOURCE_DIR) keeps this correct when the
    # onnxruntime project is built as a subproject of a larger tree.
    target_include_directories(${tool} PRIVATE "${PROJECT_SOURCE_DIR}/include")
    target_link_libraries(${tool} PRIVATE funasr)
endforeach()
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
File was renamed from funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -28,7 +28,7 @@
std::atomic<int> wav_index(0);
std::mutex mtx;
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
            float* total_length, long* total_time, int core_id) {
    
    struct timeval start, end;
@@ -59,7 +59,7 @@
        if(result){
            string msg = FunASRGetResult(result, 0);
            LOG(INFO) << "Thread: " << this_thread::get_id() <<" Result: " << msg.c_str();
            LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
            float snippet_time = FunASRGetRetSnippetTime(result);
            n_total_length += snippet_time;
@@ -146,9 +146,12 @@
    // read wav_path
    vector<string> wav_list;
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }
    else if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
@@ -162,7 +165,8 @@
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
@@ -178,7 +182,7 @@
    int rtf_threds = thread_num.getValue();
    for (int i = 0; i < rtf_threds; i++)
    {
        threads.emplace_back(thread(runReg, asr_handle, wav_list, &total_length, &total_time, i));
        threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
    }
    for (auto& thread : threads)
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
File was renamed from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
@@ -38,8 +38,8 @@
    }
}
void print_segs(vector<vector<int>>* vec) {
    string seg_out="[";
void print_segs(vector<vector<int>>* vec, string &wav_id) {
    string seg_out=wav_id + ": [";
    for (int i = 0; i < vec->size(); i++) {
        vector<int> inner_vec = (*vec)[i];
        seg_out += "[";
@@ -97,9 +97,12 @@
    // read wav_path
    vector<string> wav_list;
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }
    else if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
@@ -113,7 +116,8 @@
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
@@ -123,9 +127,11 @@
    
    float snippet_time = 0.0f;
    long taking_micros = 0;
    for(auto& wav_file : wav_list){
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -133,7 +139,7 @@
        if (result)
        {
            vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
            print_segs(vad_segments);
            print_segs(vad_segments, wav_id);
            snippet_time += FsmnVadGetRetSnippetTime(result);
            FsmnVadFreeResult(result);
        }
@@ -142,7 +148,7 @@
            LOG(ERROR) << ("No return data!\n");
        }
    }
    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
File was renamed from funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -88,9 +88,12 @@
    // read wav_path
    vector<string> wav_list;
    string wav_path_ = model_path.at(WAV_PATH);
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }
    else if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
@@ -104,7 +107,8 @@
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
@@ -114,7 +118,9 @@
    
    float snippet_time = 0.0f;
    long taking_micros = 0;
    for(auto& wav_file : wav_list){
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, 16000);
        gettimeofday(&end, NULL);
@@ -124,7 +130,7 @@
        if (result)
        {
            string msg = FunASRGetResult(result, 0);
            LOG(INFO)<<"Result: "<<msg;
            LOG(INFO)<< wav_id <<" : "<<msg;
            snippet_time += FunASRGetRetSnippetTime(result);
            FunASRFreeResult(result);
        }
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
copy from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp copy to funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
File was copied from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
@@ -18,6 +18,7 @@
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "audio.h"
using namespace std;
@@ -38,10 +39,16 @@
    }
}
void print_segs(vector<vector<int>>* vec) {
    string seg_out="[";
void print_segs(vector<vector<int>>* vec, string &wav_id) {
    if((*vec).size() == 0){
        return;
    }
    string seg_out=wav_id + ": [";
    for (int i = 0; i < vec->size(); i++) {
        vector<int> inner_vec = (*vec)[i];
        if(inner_vec.size() == 0){
            continue;
        }
        seg_out += "[";
        for (int j = 0; j < inner_vec.size(); j++) {
            seg_out += to_string(inner_vec[j]);
@@ -97,9 +104,12 @@
    // read wav_path
    vector<string> wav_list;
    vector<string> wav_ids;
    string default_id = "wav_default_id";
    string wav_path_ = model_path.at(WAV_PATH);
    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
        wav_list.emplace_back(wav_path_);
        wav_ids.emplace_back(default_id);
    }
    else if(is_target_file(wav_path_, "scp")){
        ifstream in(wav_path_);
@@ -113,39 +123,76 @@
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
            wav_list.emplace_back(column2);
            wav_ids.emplace_back(column1);
        }
        in.close();
    }else{
        LOG(ERROR)<<"Please check the wav extension!";
        exit(-1);
    }
    // init online features
    FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
    float snippet_time = 0.0f;
    long taking_micros = 0;
    for(auto& wav_file : wav_list){
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        if (result)
        {
            vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
            print_segs(vad_segments);
            snippet_time += FsmnVadGetRetSnippetTime(result);
            FsmnVadFreeResult(result);
        }
        else
        {
            LOG(ERROR) << ("No return data!\n");
        int32_t sampling_rate_ = -1;
        funasr::Audio audio(1);
        if(is_target_file(wav_file.c_str(), "wav")){
            int32_t sampling_rate_ = -1;
            if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_file;
                exit(-1);
            }
        }else if(is_target_file(wav_file.c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_file;
                exit(-1);
            }
        }else{
            LOG(ERROR)<<"Wrong wav extension";
            exit(-1);
        }
        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;
        int step = 3200;
        bool is_final = false;
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                    step = buff_len - sample_offset;
                    is_final = true;
                } else {
                    is_final = false;
            }
            gettimeofday(&start, NULL);
            FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
            gettimeofday(&end, NULL);
            seconds = (end.tv_sec - start.tv_sec);
            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
            if (result)
            {
                vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
                print_segs(vad_segments, wav_id);
                snippet_time += FsmnVadGetRetSnippetTime(result);
                FsmnVadFreeResult(result);
            }
            else
            {
                LOG(ERROR) << ("No return data!\n");
            }
        }
    }
    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
    FsmnVadUninit(online_hanlde);
    FsmnVadUninit(vad_hanlde);
    return 0;
}
funasr/runtime/onnxruntime/include/audio.h
@@ -33,8 +33,9 @@
class Audio {
  private:
    float *speech_data;
    int16_t *speech_buff;
    float *speech_data=nullptr;
    int16_t *speech_buff=nullptr;
    char* speech_char=nullptr;
    int speech_len;
    int speech_align_len;
    int offset;
@@ -47,18 +48,22 @@
    Audio(int data_type, int size);
    ~Audio();
    void Disp();
    bool LoadWav(const char* filename, int32_t* sampling_rate);
    void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
    bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
    bool LoadWav(const char* filename, int32_t* sampling_rate);
    bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
    bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
    bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
    bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
    int FetchChunck(float *&dout, int len);
    int Fetch(float *&dout, int &len, int &flag);
    void Padding();
    void Split(OfflineStream* offline_streamj);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
    float GetTimeLen();
    int GetQueueSize() { return (int)frame_queue.size(); }
    char* GetSpeechChar(){return speech_char;}
    int GetSpeechLen(){return speech_len;}
};
} // namespace funasr
funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,12 +46,6 @@
    FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
typedef enum
{
 FSMN_VAD_OFFLINE=0,
 FSMN_VAD_ONLINE = 1,
}FSMN_VAD_MODE;
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
    
// ASR
@@ -68,11 +62,12 @@
_FUNASRAPI const float    FunASRGetRetSnippetTime(FUNASR_RESULT result);
// VAD
_FUNASRAPI FUNASR_HANDLE      FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE);
_FUNASRAPI FUNASR_HANDLE      FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE      FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
// buffer
_FUNASRAPI FUNASR_RESULT    FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI FUNASR_RESULT    FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT    FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI FUNASR_RESULT    FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI std::vector<std::vector<int>>*    FsmnVadGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI void                 FsmnVadFreeResult(FUNASR_RESULT result);
funasr/runtime/onnxruntime/include/vad-model.h
@@ -12,14 +12,9 @@
    virtual ~VadModel(){};
    virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
    virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
    virtual void ReadModel(const char* vad_model)=0;
    virtual void LoadConfigFromYaml(const char* filename)=0;
    virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                    std::vector<float> &waves)=0;
    virtual void LoadCmvn(const char *filename)=0;
    virtual void InitCache()=0;
};
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode);
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
VadModel *CreateVadModel(void* fsmnvad_handle);
} // namespace funasr
#endif
funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,11 +1,8 @@
file(GLOB files1 "*.cpp")
file(GLOB files2 "*.cc")
set(files ${files1})
set(files ${files1} ${files2})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
add_library(funasr ${files})
add_library(funasr SHARED ${files})
if(WIN32)
    set(EXTRA_LIBS pthread yaml-cpp csrc glog)
@@ -24,13 +21,3 @@
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
funasr/runtime/onnxruntime/src/audio.cpp
@@ -176,12 +176,12 @@
{
    if (speech_buff != NULL) {
        free(speech_buff);
    }
    if (speech_data != NULL) {
        free(speech_data);
    }
    if (speech_char != NULL) {
        free(speech_char);
    }
}
@@ -296,8 +296,47 @@
        return false;
}
bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate)
{
    WaveHeader header;
    if (speech_char != NULL) {
        free(speech_char);
    }
    offset = 0;
    std::ifstream is(filename, std::ifstream::binary);
    is.read(reinterpret_cast<char *>(&header), sizeof(header));
    if(!is){
        LOG(ERROR) << "Failed to read " << filename;
        return false;
    }
    if (!header.Validate()) {
        return false;
    }
    header.SeekToDataChunk(is);
        if (!is) {
            return false;
    }
    if (!header.Validate()) {
        return false;
    }
    header.SeekToDataChunk(is);
    if (!is) {
        return false;
    }
    *sampling_rate = header.sample_rate;
    // header.subchunk2_size contains the number of bytes in the data.
    // As we assume each sample contains two bytes, so it is divided by 2 here
    speech_len = header.subchunk2_size / 2;
    speech_char = (char *)malloc(header.subchunk2_size);
    memset(speech_char, 0, header.subchunk2_size);
    is.read(speech_char, header.subchunk2_size);
    return true;
}
bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
{
    WaveHeader header;
    if (speech_data != NULL) {
        free(speech_data);
@@ -441,6 +480,33 @@
}
bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate)
{
    if (speech_char != NULL) {
        free(speech_char);
    }
    offset = 0;
    FILE* fp;
    fp = fopen(filename, "rb");
    if (fp == nullptr)
    {
        LOG(ERROR) << "Failed to read " << filename;
        return false;
    }
    fseek(fp, 0, SEEK_END);
    uint32_t n_file_len = ftell(fp);
    fseek(fp, 0, SEEK_SET);
    speech_len = (n_file_len) / 2;
    speech_char = (char *)malloc(n_file_len);
    memset(speech_char, 0, n_file_len);
    fread(speech_char, sizeof(int16_t), n_file_len/2, fp);
    fclose(fp);
    return true;
}
int Audio::FetchChunck(float *&dout, int len)
{
    if (offset >= speech_align_len) {
@@ -541,7 +607,7 @@
}
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
{
    AudioFrame *frame;
@@ -552,7 +618,7 @@
    frame = NULL;
    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vad_segments = vad_obj->Infer(pcm_data);
    vad_segments = vad_obj->Infer(pcm_data, input_finished);
}
} // namespace funasr
funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -103,9 +103,10 @@
        vector<string> WordWithPunc;
        for (int i = 0; i < InputStr.size(); i++)
        {
            if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// �м��Ӣ�ģ�
            // if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// �м��Ӣ�ģ�
            if (i > 0 && !(InputStr[i-1][0] & 0x80) && !(InputStr[i][0] & 0x80))
            {
                InputStr[i] = InputStr[i]+ " ";
                InputStr[i] = " " + InputStr[i];
            }
            WordWithPunc.push_back(InputStr[i]);
@@ -187,4 +188,4 @@
    return punction;
}
} // namespace funasr
} // namespace funasr
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
New file
@@ -0,0 +1,198 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include <fstream>
#include "precomp.h"
namespace funasr {
void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                               std::vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts_);
    // cache merge
    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
    int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
    // Send the audio after the last frame shift position to the cache
    input_cache_.clear();
    input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
    if (frame_number == 0) {
        return;
    }
    // Delete audio that haven't undergone fbank processing
    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
    std::vector<float> buf(waves.size());
    for (int32_t i = 0; i != waves.size(); ++i) {
        buf[i] = waves[i] * 32768;
    }
    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
    // fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
    int32_t frames = fbank.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        vad_feats.emplace_back(frame_vector);
    }
}
void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &vad_feats,
                                 vector<float> &waves, bool input_finished) {
  FbankKaldi(sample_rate, vad_feats, waves);
  // cache deal & online lfr,cmvn
  if (vad_feats.size() > 0) {
    if (!reserve_waveforms_.empty()) {
      waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
    }
    if (lfr_splice_cache_.empty()) {
      for (int i = 0; i < (lfr_m - 1) / 2; i++) {
        lfr_splice_cache_.emplace_back(vad_feats[0]);
      }
    }
    if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
      vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
      int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
      int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
      int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
      int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
      reserve_waveforms_.clear();
      reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                                waves.begin() + frame_from_waves * frame_shift_sample_length_);
      int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
      waves.erase(waves.begin() + sample_length, waves.end());
    } else {
      reserve_waveforms_.clear();
      reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
      lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
    }
  } else {
    if (input_finished) {
      if (!reserve_waveforms_.empty()) {
        waves = reserve_waveforms_;
      }
      vad_feats = lfr_splice_cache_;
      OnlineLfrCmvn(vad_feats, input_finished);
    }
  }
  if(input_finished){
      Reset();
      ResetCache();
  }
}
int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
    vector<vector<float>> out_feats;
    int T = vad_feats.size();
    int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
    int lfr_splice_frame_idxs = T_lrf;
    vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            if (input_finished) {
                int num_padding = lfr_m - (T - i * lfr_n);
                for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
                    p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
                }
                for (int j = 0; j < num_padding; j++) {
                    p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
                }
                out_feats.emplace_back(p);
            } else {
                lfr_splice_frame_idxs = i;
                break;
            }
        }
    }
    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
    lfr_splice_cache_.clear();
    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    vad_feats = out_feats;
    return lfr_splice_frame_idxs;
}
std::vector<std::vector<int>>
FsmnVadOnline::Infer(std::vector<float> &waves, bool input_finished) {
    std::vector<std::vector<float>> vad_feats;
    std::vector<std::vector<float>> vad_probs;
    ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished);
    fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
    std::vector<std::vector<int>> vad_segments;
    vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_,
                              vad_speech_noise_thres_, vad_sample_rate_);
    return vad_segments;
}
void FsmnVadOnline::InitCache(){
  std::vector<float> cache_feats(128 * 19 * 1, 0);
  for (int i=0;i<4;i++){
    in_cache_.emplace_back(cache_feats);
  }
};
void FsmnVadOnline::Reset(){
  in_cache_.clear();
  InitCache();
};
// Empty stub; FsmnVad declares the same Test() entry point.
void FsmnVadOnline::Test() {
}
void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
                               Ort::Env &env,
                               std::vector<const char *> &vad_in_names,
                               std::vector<const char *> &vad_out_names,
                               knf::FbankOptions &fbank_opts,
                               std::vector<float> &means_list,
                               std::vector<float> &vars_list,
                               int vad_sample_rate,
                               int vad_silence_duration,
                               int vad_max_len,
                               double vad_speech_noise_thres) {
    vad_session_ = vad_session;
    vad_in_names_ = vad_in_names;
    vad_out_names_ = vad_out_names;
    fbank_opts_ = fbank_opts;
    means_list_ = means_list;
    vars_list_ = vars_list;
    vad_sample_rate_ = vad_sample_rate;
    vad_silence_duration_ = vad_silence_duration;
    vad_max_len_ = vad_max_len;
    vad_speech_noise_thres_ = vad_speech_noise_thres;
}
// Nothing to release explicitly: fsmnvad_handle_ is borrowed (never deleted here) and the
// remaining members clean up after themselves.
FsmnVadOnline::~FsmnVadOnline() = default;
FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle):fsmnvad_handle_(std::move(fsmnvad_handle)),session_options_{}{
   InitCache();
   InitOnline(fsmnvad_handle_->vad_session_,
              fsmnvad_handle_->env_,
              fsmnvad_handle_->vad_in_names_,
              fsmnvad_handle_->vad_out_names_,
              fsmnvad_handle_->fbank_opts_,
              fsmnvad_handle_->means_list_,
              fsmnvad_handle_->vars_list_,
              fsmnvad_handle_->vad_sample_rate_,
              fsmnvad_handle_->vad_silence_duration_,
              fsmnvad_handle_->vad_max_len_,
              fsmnvad_handle_->vad_speech_noise_thres_);
}
} // namespace funasr
funasr/runtime/onnxruntime/src/fsmn-vad-online.h
New file
@@ -0,0 +1,88 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#pragma once
#include "precomp.h"
namespace funasr {
class FsmnVadOnline : public VadModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 * https://arxiv.org/abs/1803.05030
*/
public:
    explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
    ~FsmnVadOnline();
    void Test();
    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
    void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
    void Reset();
private:
    E2EVadModel vad_scorer = E2EVadModel();
    // std::unique_ptr<FsmnVad> fsmnvad_handle_;
    FsmnVad* fsmnvad_handle_ = nullptr;
    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                    std::vector<float> &waves);
    int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
    void InitCache();
    void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
                    Ort::Env &env,
                    std::vector<const char *> &vad_in_names,
                    std::vector<const char *> &vad_out_names,
                    knf::FbankOptions &fbank_opts,
                    std::vector<float> &means_list,
                    std::vector<float> &vars_list,
                    int vad_sample_rate,
                    int vad_silence_duration,
                    int vad_max_len,
                    double vad_speech_noise_thres);
    static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
        int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
        if (frame_num >= 1 && sample_length >= frame_sample_length)
            return frame_num;
        else
            return 0;
    }
    void ResetCache() {
        reserve_waveforms_.clear();
        input_cache_.clear();
        lfr_splice_cache_.clear();
    }
    // from fsmnvad_handle_
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    knf::FbankOptions fbank_opts_;
    std::vector<float> means_list_;
    std::vector<float> vars_list_;
    std::vector<std::vector<float>> in_cache_;
    // The reserved waveforms by fbank
    std::vector<float> reserve_waveforms_;
    // waveforms reserved after last shift position
    std::vector<float> input_cache_;
    // lfr reserved cache
    std::vector<std::vector<float>> lfr_splice_cache_;
    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
    int vad_silence_duration_ = VAD_SILENCE_DURATION;
    int vad_max_len_ = VAD_MAX_LEN;
    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
    int lfr_m = VAD_LFR_M;
    int lfr_n = VAD_LFR_N;
    int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
    int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
};
} // namespace funasr
funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -37,14 +37,14 @@
        this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
        this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
        fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
        fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
        fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
        fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
        fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
        fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
        fbank_opts.energy_floor = 0;
        fbank_opts.mel_opts.debug_mel = false;
        fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
        fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
        fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
        fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
        fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
        fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
        fbank_opts_.energy_floor = 0;
        fbank_opts_.mel_opts.debug_mel = false;
    }catch(exception const &e){
        LOG(ERROR) << "Error when load argument from vad config YAML.";
        exit(-1);
@@ -55,6 +55,7 @@
    try {
        vad_session_ = std::make_shared<Ort::Session>(
                env_, vad_model, session_options_);
        LOG(INFO) << "Successfully load model from " << vad_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load vad onnx model: " << e.what();
        exit(0);
@@ -109,7 +110,9 @@
void FsmnVad::Forward(
        const std::vector<std::vector<float>> &chunk_feats,
        std::vector<std::vector<float>> *out_prob) {
        std::vector<std::vector<float>> *out_prob,
        std::vector<std::vector<float>> *in_cache,
        bool is_final) {
    Ort::MemoryInfo memory_info =
            Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -132,9 +135,9 @@
    // 4 caches
    // cache node {batch,128,19,1}
    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
    for (int i = 0; i < in_cache_.size(); i++) {
    for (int i = 0; i < in_cache->size(); i++) {
      vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
              memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
              memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
    }
  
    // 4. Onnx infer
@@ -162,15 +165,17 @@
    }
  
    // get 4 caches outputs,each size is 128*19
    // for (int i = 1; i < 5; i++) {
    //   float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
    //   memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
    // }
    if(!is_final){
        for (int i = 1; i < 5; i++) {
        float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
        memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
        }
    }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                         std::vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts);
    knf::OnlineFbank fbank(fbank_opts_);
    std::vector<float> buf(waves.size());
    for (int32_t i = 0; i != waves.size(); ++i) {
@@ -180,7 +185,7 @@
    int32_t frames = fbank.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        vad_feats.emplace_back(frame_vector);
    }
}
@@ -205,7 +210,7 @@
                vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
                if (means_lines[0] == "<LearnRateCoef>") {
                    for (int j = 3; j < means_lines.size() - 1; j++) {
                        means_list.push_back(stof(means_lines[j]));
                        means_list_.push_back(stof(means_lines[j]));
                    }
                    continue;
                }
@@ -216,8 +221,8 @@
                vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
                if (vars_lines[0] == "<LearnRateCoef>") {
                    for (int j = 3; j < vars_lines.size() - 1; j++) {
                        // vars_list.push_back(stof(vars_lines[j])*scale);
                        vars_list.push_back(stof(vars_lines[j]));
                        // vars_list_.push_back(stof(vars_lines[j])*scale);
                        vars_list_.push_back(stof(vars_lines[j]));
                    }
                    continue;
                }
@@ -263,8 +268,8 @@
    }
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    vad_feats = out_feats;
@@ -276,7 +281,7 @@
    std::vector<std::vector<float>> vad_probs;
    FbankKaldi(vad_sample_rate_, vad_feats, waves);
    LfrCmvn(vad_feats);
    Forward(vad_feats, &vad_probs);
    Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
    E2EVadModel vad_scorer = E2EVadModel();
    std::vector<std::vector<int>> vad_segments;
funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -22,7 +22,30 @@
    void Test();
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
    void Forward(
        const std::vector<std::vector<float>> &chunk_feats,
        std::vector<std::vector<float>> *out_prob,
        std::vector<std::vector<float>> *in_cache,
        bool is_final);
    void Reset();
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    std::vector<std::vector<float>> in_cache_;
    knf::FbankOptions fbank_opts_;
    std::vector<float> means_list_;
    std::vector<float> vars_list_;
    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
    int vad_silence_duration_ = VAD_SILENCE_DURATION;
    int vad_max_len_ = VAD_MAX_LEN;
    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
    int lfr_m = VAD_LFR_M;
    int lfr_n = VAD_LFR_N;
private:
@@ -37,31 +60,9 @@
                    std::vector<float> &waves);
    void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
    void Forward(
            const std::vector<std::vector<float>> &chunk_feats,
            std::vector<std::vector<float>> *out_prob);
    void LoadCmvn(const char *filename);
    void InitCache();
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    std::vector<std::vector<float>> in_cache_;
    knf::FbankOptions fbank_opts;
    std::vector<float> means_list;
    std::vector<float> vars_list;
    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
    int vad_silence_duration_ = VAD_SILENCE_DURATION;
    int vad_max_len_ = VAD_MAX_LEN;
    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
    int lfr_m = VAD_LFR_M;
    int lfr_n = VAD_LFR_N;
};
} // namespace funasr
funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -11,9 +11,15 @@
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode)
    _FUNASRAPI FUNASR_HANDLE  FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
    {
        funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode);
        funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle)
    {
        funasr::VadModel* mm = funasr::CreateVadModel(fsmnvad_handle);
        return mm;
    }
@@ -96,7 +102,7 @@
    }
    // APIs for VAD Infer
    _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
    _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
    {
        funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
        if (!vad_obj)
@@ -110,13 +116,13 @@
        p_result->snippet_time = audio.GetTimeLen();
        
        vector<std::vector<int>> vad_segments;
        audio.Split(vad_obj, vad_segments);
        audio.Split(vad_obj, vad_segments, input_finished);
        p_result->segments = new vector<std::vector<int>>(vad_segments);
        return p_result;
    }
    _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
    _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate)
    {
        funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
        if (!vad_obj)
@@ -139,7 +145,7 @@
        p_result->snippet_time = audio.GetTimeLen();
        
        vector<std::vector<int>> vad_segments;
        audio.Split(vad_obj, vad_segments);
        audio.Split(vad_obj, vad_segments, true);
        p_result->segments = new vector<std::vector<int>>(vad_segments);
        return p_result;
funasr/runtime/onnxruntime/src/online-feature.cpp
File was deleted
funasr/runtime/onnxruntime/src/online-feature.h
File was deleted
funasr/runtime/onnxruntime/src/paraformer.h
@@ -18,7 +18,7 @@
        //std::unique_ptr<knf::OnlineFbank> fbank_;
        knf::FbankOptions fbank_opts;
        Vocab* vocab;
        Vocab* vocab = nullptr;
        vector<float> means_list;
        vector<float> vars_list;
        const float scale = 22.6274169979695;
@@ -30,7 +30,7 @@
        void ApplyCmvn(vector<float> *v);
        string GreedySearch( float* in, int n_len, int64_t token_nums);
        std::shared_ptr<Ort::Session> m_session;
        std::shared_ptr<Ort::Session> m_session = nullptr;
        Ort::Env env_;
        Ort::SessionOptions session_options;
funasr/runtime/onnxruntime/src/precomp.h
@@ -36,8 +36,9 @@
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "fsmn-vad.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "fsmn-vad-online.h"
#include "vocab.h"
#include "audio.h"
#include "tensor.h"
funasr/runtime/onnxruntime/src/vad-model.cpp
@@ -1,14 +1,10 @@
#include "precomp.h"
namespace funasr {
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
{
    VadModel *mm;
    if(mode == FSMN_VAD_OFFLINE){
        mm = new FsmnVad();
    }else{
        LOG(ERROR)<<"Online fsmn vad not imp!";
    }
    mm = new FsmnVad();
    string vad_model_path;
    string vad_cmvn_path;
@@ -25,4 +21,11 @@
    return mm;
}
VadModel *CreateVadModel(void* fsmnvad_handle)
{
    VadModel *mm;
    mm = new FsmnVadOnline((FsmnVad*)fsmnvad_handle);
    return mm;
}
} // namespace funasr
funasr/runtime/python/websocket/wss_srv_asr.py
@@ -58,16 +58,36 @@
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.6',
    model_revision='v1.0.4',
    update_model='v1.0.4',
    mode='paraformer_streaming')
print("model loaded")
print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
    """Reset one connection's model state and close it.

    The connection's VAD and streaming-ASR param dicts are replaced with fresh caches
    marked ``is_final``. Both pipelines are then run once on one second of int16
    silence, presumably to clear any streaming state held inside the pipelines
    (confirm). Finally the websocket is closed.
    """
    print("ws reset now, total num is ",len(websocket_users))
    websocket.param_dict_asr_online = {"cache": dict(), "is_final": True}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
    silence = np.zeros(16000, dtype=np.int16).tobytes()
    inference_pipeline_vad(audio_in=silence, param_dict=websocket.param_dict_vad)
    inference_pipeline_asr_online(audio_in=silence, param_dict=websocket.param_dict_asr_online)
    await websocket.close()
async def clear_websocket():
    """Reset and close every tracked connection, then forget all of them.

    The loop iterates over a snapshot of ``websocket_users``. ``ws_reset`` awaits
    ``websocket.close()``, and while it is suspended other handlers can mutate the
    set (the ConnectionClosed handler in ``ws_serve`` removes its websocket).
    Iterating the live set would then raise "Set changed size during iteration".
    """
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()
async def ws_serve(websocket, path):
    frames = []
    frames_asr = []
    frames_asr_online = []
    global websocket_users
    await clear_websocket()
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
@@ -75,7 +95,7 @@
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    print("new user connected", flush=True)
@@ -104,7 +124,7 @@
        
                    # asr online
                    frames_asr_online.append(message)
                    websocket.param_dict_asr_online["is_final"] = speech_end_i
                    websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr_online)
@@ -114,14 +134,14 @@
                        frames_asr.append(message)
                    # vad online
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i:
                    if speech_start_i != -1:
                        speech_start = True
                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i or not websocket.is_speaking:
                if speech_end_i != -1 or not websocket.is_speaking:
                    # print("vad end point")
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
@@ -139,7 +159,8 @@
     
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
        print("ConnectionClosed...", websocket_users,flush=True)
        await ws_reset(websocket)
        websocket_users.remove(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
@@ -151,15 +172,15 @@
    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
    speech_start = False
    speech_end = False
    speech_start = -1
    speech_end = -1
    
    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
        return speech_start, speech_end
    if segments_result["text"][0][0] != -1:
        speech_start = segments_result["text"][0][0]
    if segments_result["text"][0][1] != -1:
        speech_end = True
        speech_end = segments_result["text"][0][1]
    return speech_start, speech_end
funasr/runtime/websocket/CMakeLists.txt
@@ -2,6 +2,8 @@
project(FunASRWebscoket) 
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
funasr/version.txt
@@ -1 +1 @@
0.5.7
0.5.8