From 653088af82ab334c3d4a8ca52bb9ac4d724d18a4 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 09 Aug 2023 16:40:14 +0800
Subject: [PATCH] update paraformer-online
---
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp | 2 +-
funasr/runtime/onnxruntime/src/paraformer-online.cpp | 36 ++++++++++++++++++------------------
2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
index 806c1be..1681ce2 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
@@ -202,7 +202,7 @@
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
- TCLAP::ValueArg<std::int32_t> thread_num_("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
+ TCLAP::ValueArg<std::int32_t> thread_num_("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.cpp b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
index 267d30a..dd7e8e1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -101,27 +101,27 @@
waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
}
if (lfr_splice_cache_.empty()) {
- for (int i = 0; i < (lfr_m - 1) / 2; i++) {
- lfr_splice_cache_.emplace_back(wav_feats[0]);
- }
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ lfr_splice_cache_.emplace_back(wav_feats[0]);
+ }
}
if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
- wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
- int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
- int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
- int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
- int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
- waves.begin() + frame_from_waves * frame_shift_sample_length_);
- int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
- waves.erase(waves.begin() + sample_length, waves.end());
+ wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+ int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+ int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+ int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
+ int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+ waves.begin() + frame_from_waves * frame_shift_sample_length_);
+ int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+ waves.erase(waves.begin() + sample_length, waves.end());
} else {
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
- lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+ lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
}
} else {
if (input_finished) {
--
Gitblit v1.9.1