From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 29 六月 2023 18:55:01 +0800
Subject: [PATCH] dcos
---
funasr/runtime/onnxruntime/src/paraformer.cpp | 89 ++++++++++++++++++--------------------------
1 files changed, 37 insertions(+), 52 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..b605fff 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,19 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#include "precomp.h"
using namespace std;
-using namespace paraformer;
-Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+namespace funasr {
+
+Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
- string model_path;
- string cmvn_path;
- string config_path;
+}
- // VAD model
- if(use_vad){
- string vad_path = PathAppend(path, "vad_model.onnx");
- string mvn_path = PathAppend(path, "vad.mvn");
- vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
- }
-
- // PUNC model
- if(use_punc){
- punc_handle = make_unique<CTTransformer>(path, thread_num);
- }
-
- if(quantize)
- {
- model_path = PathAppend(path, "model_quant.onnx");
- }else{
- model_path = PathAppend(path, "model.onnx");
- }
- cmvn_path = PathAppend(path, "am.mvn");
- config_path = PathAppend(path, "config.yaml");
-
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +31,13 @@
// DisableCpuMemArena can improve performance
session_options.DisableCpuMemArena();
-#ifdef _WIN32
- wstring wstrPath = strToWstr(model_path);
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#else
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#endif
+ try {
+ m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << am_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am onnx model: " << e.what();
+ exit(0);
+ }
string strName;
GetInputName(m_session.get(), strName);
@@ -70,8 +54,8 @@
m_szInputNames.push_back(item.c_str());
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- vocab = new Vocab(config_path.c_str());
- LoadCmvn(cmvn_path.c_str());
+ vocab = new Vocab(am_config.c_str());
+ LoadCmvn(am_cmvn.c_str());
}
Paraformer::~Paraformer()
@@ -84,17 +68,13 @@
{
}
-vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
- return vad_handle->Infer(pcm_data);
-}
-
-string Paraformer::AddPunc(const char* sz_input){
- return punc_handle->AddPunc(sz_input);
-}
-
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
knf::OnlineFbank fbank_(fbank_opts);
- fbank_.AcceptWaveform(sample_rate, waves, len);
+ std::vector<float> buf(len);
+ for (int32_t i = 0; i != len; ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
//fbank_->InputFinished();
int32_t frames = fbank_.NumFramesReady();
int32_t feature_dim = fbank_opts.mel_opts.num_bins;
@@ -113,6 +93,10 @@
void Paraformer::LoadCmvn(const char *filename)
{
ifstream cmvn_stream(filename);
+ if (!cmvn_stream.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << filename;
+ exit(0);
+ }
string line;
while (getline(cmvn_stream, line)) {
@@ -143,14 +127,14 @@
}
}
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
{
vector<int> hyps;
int Tmax = n_len;
for (int i = 0; i < Tmax; i++) {
int max_idx;
float max_val;
- FindMax(in + i * 8404, 8404, max_val, max_idx);
+ FindMax(in + i * token_nums, token_nums, max_val, max_idx);
hyps.push_back(max_idx);
}
@@ -238,11 +222,11 @@
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
- result = GreedySearch(floatData, *encoder_out_lens);
+ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
}
catch (std::exception const &e)
{
- printf(e.what());
+ LOG(ERROR)<<e.what();
}
return result;
@@ -251,12 +235,13 @@
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
- printf("Not Imp!!!!!!\n");
- return "Hello";
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
}
string Paraformer::Rescoring()
{
- printf("Not Imp!!!!!!\n");
- return "Hello";
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
}
+} // namespace funasr
--
Gitblit v1.9.1