From a0048dc766f233827f2e7b5ebed0a0e22fae44b1 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Wed, 13 Dec 2023 09:50:58 +0800
Subject: [PATCH] 8k (#1174) (#1175)
---
runtime/onnxruntime/src/paraformer.h | 5 +-
runtime/onnxruntime/src/fsmn-vad-online.cpp | 5 ++
runtime/onnxruntime/src/paraformer-online.cpp | 8 +++-
runtime/onnxruntime/include/vad-model.h | 1
runtime/onnxruntime/src/fsmn-vad-online.h | 2 +
runtime/onnxruntime/src/fsmn-vad.h | 2 +
runtime/onnxruntime/src/audio.cpp | 44 +++++++++++++--------
runtime/onnxruntime/src/funasrruntime.cpp | 12 +++---
runtime/onnxruntime/include/model.h | 2 +
runtime/onnxruntime/src/paraformer-online.h | 3 +
runtime/onnxruntime/include/audio.h | 5 +-
runtime/onnxruntime/src/paraformer.cpp | 12 ++++-
12 files changed, 67 insertions(+), 34 deletions(-)
diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index 34fcbaf..ce9e16b 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -52,10 +52,11 @@
queue<AudioFrame *> frame_queue;
queue<AudioFrame *> asr_online_queue;
queue<AudioFrame *> asr_offline_queue;
-
+ int dest_sample_rate;
public:
Audio(int data_type);
- Audio(int data_type, int size);
+ Audio(int model_sample_rate,int data_type);
+ Audio(int model_sample_rate,int data_type, int size);
~Audio();
void Disp();
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 7b58e92..33caec8 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -23,6 +23,8 @@
virtual void InitSegDict(const std::string &seg_dict_model){};
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
virtual std::string GetLang(){return "";};
+ virtual int GetAsrSampleRate() = 0;
+
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
diff --git a/runtime/onnxruntime/include/vad-model.h b/runtime/onnxruntime/include/vad-model.h
index 07f1833..adb1e20 100644
--- a/runtime/onnxruntime/include/vad-model.h
+++ b/runtime/onnxruntime/include/vad-model.h
@@ -12,6 +12,7 @@
virtual ~VadModel(){};
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
+ virtual int GetVadSampleRate() = 0;
};
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index b543797..132f47d 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -193,18 +193,28 @@
return 0;
}
-Audio::Audio(int data_type) : data_type(data_type)
+Audio::Audio(int data_type) : dest_sample_rate(MODEL_SAMPLE_RATE), data_type(data_type)
{
speech_buff = NULL;
speech_data = NULL;
align_size = 1360;
+ seg_sample = dest_sample_rate / 1000;
}
-Audio::Audio(int data_type, int size) : data_type(data_type)
+Audio::Audio(int model_sample_rate, int data_type) : dest_sample_rate(model_sample_rate), data_type(data_type)
+{
+ speech_buff = NULL;
+ speech_data = NULL;
+ align_size = 1360;
+ seg_sample = dest_sample_rate / 1000;
+}
+
+Audio::Audio(int model_sample_rate, int data_type, int size) : dest_sample_rate(model_sample_rate), data_type(data_type)
{
speech_buff = NULL;
speech_data = NULL;
align_size = (float)size;
+ seg_sample = dest_sample_rate / 1000;
}
Audio::~Audio()
@@ -222,12 +232,12 @@
void Audio::Disp()
{
- LOG(INFO) << "Audio time is " << (float)speech_len / MODEL_SAMPLE_RATE << " s. len is " << speech_len;
+ LOG(INFO) << "Audio time is " << (float)speech_len / dest_sample_rate << " s. len is " << speech_len;
}
float Audio::GetTimeLen()
{
- return (float)speech_len / MODEL_SAMPLE_RATE;
+ return (float)speech_len / dest_sample_rate;
}
void Audio::WavResample(int32_t sampling_rate, const float *waveform,
@@ -237,13 +247,13 @@
<< " in_sample_rate: "<< sampling_rate << "\n"
- << " output_sample_rate: " << static_cast<int32_t>(MODEL_SAMPLE_RATE);
+ << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
float min_freq =
- std::min<int32_t>(sampling_rate, MODEL_SAMPLE_RATE);
+ std::min<int32_t>(sampling_rate, dest_sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
- sampling_rate, MODEL_SAMPLE_RATE, lowpass_cutoff, lowpass_filter_width);
+ sampling_rate, dest_sample_rate, lowpass_cutoff, lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
//reset speech_data
@@ -311,7 +321,7 @@
nullptr, // allocate a new context
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
- 16000, // output sample rate (same as input)
+ dest_sample_rate, // output sample rate (the model's target rate)
av_get_default_channel_layout(codecContext->channels), // input channel layout
codecContext->sample_fmt, // input sample format
codecContext->sample_rate, // input sample rate
@@ -347,7 +357,7 @@
int in_samples = frame->nb_samples;
uint8_t **in_data = frame->extended_data;
int out_samples = av_rescale_rnd(in_samples,
- 16000,
+ dest_sample_rate,
codecContext->sample_rate,
AV_ROUND_DOWN);
@@ -494,7 +504,7 @@
nullptr, // allocate a new context
AV_CH_LAYOUT_MONO, // output channel layout (stereo)
AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
- 16000, // output sample rate (same as input)
+ dest_sample_rate, // output sample rate (the model's target rate)
av_get_default_channel_layout(codecContext->channels), // input channel layout
codecContext->sample_fmt, // input sample format
codecContext->sample_rate, // input sample rate
@@ -532,7 +542,7 @@
int in_samples = frame->nb_samples;
uint8_t **in_data = frame->extended_data;
int out_samples = av_rescale_rnd(in_samples,
- 16000,
+ dest_sample_rate,
codecContext->sample_rate,
AV_ROUND_DOWN);
@@ -666,7 +676,7 @@
}
//resample
- if(*sampling_rate != MODEL_SAMPLE_RATE){
+ if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}
@@ -752,7 +762,7 @@
}
//resample
- if(*sampling_rate != MODEL_SAMPLE_RATE){
+ if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}
@@ -795,7 +805,7 @@
}
//resample
- if(*sampling_rate != MODEL_SAMPLE_RATE){
+ if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}
@@ -840,7 +850,7 @@
}
//resample
- if(*sampling_rate != MODEL_SAMPLE_RATE){
+ if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}
@@ -898,7 +908,7 @@
}
//resample
- if(*sampling_rate != MODEL_SAMPLE_RATE){
+ if(*sampling_rate != dest_sample_rate){
WavResample(*sampling_rate, speech_data, speech_len);
}
@@ -1009,7 +1019,7 @@
AudioFrame *frame = frame_queue.front();
frame_queue.pop();
- start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
+ start_time = (float)(frame->GetStart())/ dest_sample_rate;
dout = speech_data + frame->GetStart();
len = frame->GetLen();
delete frame;
@@ -1248,7 +1258,7 @@
}
// erase all_samples
- int vector_cache = MODEL_SAMPLE_RATE*2;
+ int vector_cache = dest_sample_rate*2;
if(speech_offline_start == -1){
if(all_samples.size() > vector_cache){
int erase_num = all_samples.size() - vector_cache;
diff --git a/runtime/onnxruntime/src/fsmn-vad-online.cpp b/runtime/onnxruntime/src/fsmn-vad-online.cpp
index a8cc5d8..30627fc 100644
--- a/runtime/onnxruntime/src/fsmn-vad-online.cpp
+++ b/runtime/onnxruntime/src/fsmn-vad-online.cpp
@@ -187,8 +187,11 @@
vad_max_len_ = vad_max_len;
vad_speech_noise_thres_ = vad_speech_noise_thres;
+ frame_sample_length_ = vad_sample_rate_ / 1000 * 25; // samples in 25 ms
+ frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10; // samples in 10 ms
+
// 2pass
- audio_handle = make_unique<Audio>(1);
+ audio_handle = make_unique<Audio>(vad_sample_rate,1);
}
FsmnVadOnline::~FsmnVadOnline() {
diff --git a/runtime/onnxruntime/src/fsmn-vad-online.h b/runtime/onnxruntime/src/fsmn-vad-online.h
index 9191304..4c82d11 100644
--- a/runtime/onnxruntime/src/fsmn-vad-online.h
+++ b/runtime/onnxruntime/src/fsmn-vad-online.h
@@ -21,6 +21,8 @@
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
void Reset();
+ int GetVadSampleRate() { return vad_sample_rate_; };
+
// 2pass
std::unique_ptr<Audio> audio_handle = nullptr;
diff --git a/runtime/onnxruntime/src/fsmn-vad.h b/runtime/onnxruntime/src/fsmn-vad.h
index adceb1f..f06a965 100644
--- a/runtime/onnxruntime/src/fsmn-vad.h
+++ b/runtime/onnxruntime/src/fsmn-vad.h
@@ -28,6 +28,8 @@
std::vector<std::vector<float>> *in_cache,
bool is_final);
void Reset();
+
+ int GetVadSampleRate() { return vad_sample_rate_; };
std::shared_ptr<Ort::Session> vad_session_ = nullptr;
Ort::Env env_;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 5c2653f..21f7d82 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -57,7 +57,7 @@
if (!recog_obj)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
@@ -93,7 +93,7 @@
if (!recog_obj)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -134,7 +134,7 @@
if (!vad_obj)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
@@ -162,7 +162,7 @@
if (!vad_obj)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -222,7 +222,7 @@
if (!offline_stream)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio(offline_stream->asr_handle->GetAsrSampleRate(),1);
try{
if(wav_format == "pcm" || wav_format == "PCM"){
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
@@ -314,7 +314,7 @@
if (!offline_stream)
return nullptr;
- funasr::Audio audio(1);
+ funasr::Audio audio((offline_stream->asr_handle)->GetAsrSampleRate(),1);
try{
if(funasr::is_target_file(sz_filename, "wav")){
int32_t sampling_rate_ = -1;
diff --git a/runtime/onnxruntime/src/paraformer-online.cpp b/runtime/onnxruntime/src/paraformer-online.cpp
index d08b57e..55a4fd1 100644
--- a/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/runtime/onnxruntime/src/paraformer-online.cpp
@@ -61,7 +61,11 @@
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
fsmn_init_cache_.emplace_back(0);
}
- chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
+ chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
+
+ frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
+ frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
+
}
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
@@ -489,7 +493,7 @@
if(is_first_chunk){
is_first_chunk = false;
}
- ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
+ ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
if(wav_feats.size() == 0){
return result;
}
diff --git a/runtime/onnxruntime/src/paraformer-online.h b/runtime/onnxruntime/src/paraformer-online.h
index 138c77c..8c9bb88 100644
--- a/runtime/onnxruntime/src/paraformer-online.h
+++ b/runtime/onnxruntime/src/paraformer-online.h
@@ -111,6 +111,9 @@
string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
string Rescoring();
+
+ int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
+
// 2pass
std::string online_res;
int chunk_len;
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index 3de3e39..2c78fe5 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -19,10 +19,11 @@
// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+ LoadConfigFromYaml(am_config.c_str());
// knf options
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
- fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+ fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -65,7 +66,6 @@
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
vocab = new Vocab(am_config.c_str());
- LoadConfigFromYaml(am_config.c_str());
phone_set_ = new PhoneSet(am_config.c_str());
LoadCmvn(am_cmvn.c_str());
}
@@ -77,7 +77,7 @@
// knf options
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
- fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+ fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -216,6 +216,9 @@
}
try{
+ YAML::Node frontend_conf = config["frontend_conf"];
+ this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
YAML::Node lang_conf = config["lang"];
if (lang_conf.IsDefined()){
language = lang_conf.as<string>();
@@ -258,6 +261,9 @@
this->cif_threshold = predictor_conf["threshold"].as<double>();
this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
+ this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+
}catch(exception const &e){
LOG(ERROR) << "Error when load argument from vad config YAML.";
exit(-1);
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 89c8b09..de05657 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -57,7 +57,7 @@
string Rescoring();
string GetLang(){return language;};
-
+ int GetAsrSampleRate() { return asr_sample_rate; };
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -107,8 +107,7 @@
int fsmn_dims = 512;
float cif_threshold = 1.0;
float tail_alphas = 0.45;
-
-
+ int asr_sample_rate = MODEL_SAMPLE_RATE;
};
} // namespace funasr
--
Gitblit v1.9.1