From a58c3d4593892459614aabda689fb61af74e20fe Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Fri, 29 Mar 2024 16:47:27 +0800
Subject: [PATCH] add batch for paraformer

---
 runtime/onnxruntime/src/paraformer.h         |    4 +
 runtime/onnxruntime/src/paraformer-torch.cpp |  127 +++++++++++++++++++++++-------------------
 runtime/onnxruntime/src/paraformer-torch.h   |    5 +
 runtime/onnxruntime/src/paraformer.cpp       |   24 +++++--
 4 files changed, 94 insertions(+), 66 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index 06c88f6..e7fbadf 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -265,34 +265,45 @@
     asr_feats = out_feats;
 }
 
-string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
 {
     WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
     int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+    int32_t feature_dim = lfr_m*in_feat_dim;
 
-    std::vector<std::vector<float>> asr_feats;
-    FbankKaldi(asr_sample_rate, din, len, asr_feats);
-    if(asr_feats.size() == 0){
-      return "";
-    }
-    LfrCmvn(asr_feats);
-    int32_t feat_dim = lfr_m*in_feat_dim;
-    int32_t num_frames = asr_feats.size();
-
-    std::vector<float> wav_feats;
-    for (const auto &frame_feat: asr_feats) {
-        wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
-    }
+    std::vector<vector<float>> feats_batch;
     std::vector<int32_t> paraformer_length;
-    paraformer_length.emplace_back(num_frames);
+    int max_size = 0;
+    int max_frames = 0;
+    for(int index=0; index<batch_in; index++){
+        std::vector<std::vector<float>> asr_feats;
+        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
+        if(asr_feats.size() != 0){ LfrCmvn(asr_feats); }
+        // flatten per-frame features so each batch entry is one contiguous buffer
+        std::vector<float> wav_feats;
+        for (const auto &frame_feat: asr_feats)
+            wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
+        int32_t num_frames = asr_feats.size();
+        feats_batch.emplace_back(std::move(wav_feats));
+        paraformer_length.emplace_back(num_frames);
+        if(max_frames < num_frames){ max_frames = num_frames; max_size = max_frames * feature_dim; }
+    }
 
     torch::NoGradGuard no_grad;
     model_->eval();
+    // padding
+    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
+    for(int index=0; index<batch_in; index++){
+        feats_batch[index].resize(max_size);
+        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
+                        max_frames * feature_dim * sizeof(float));
+    }
     torch::Tensor feats =
-        torch::from_blob(wav_feats.data(),
-                {1, num_frames, feat_dim}, torch::kFloat).contiguous();
+        torch::from_blob(all_feats.data(),
+                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
     torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
-                        {1}, torch::kInt32);
+                        {batch_in}, torch::kInt32);
 
     // 2. forward
     #ifdef USE_GPU
@@ -301,7 +312,7 @@
     #endif
     std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
 
-    string result="";
+    vector<std::string> results;
     try {
         auto outputs = model_->forward(inputs).toTuple()->elements();
         torch::Tensor am_scores;
@@ -314,47 +325,49 @@
         valid_token_lens = outputs[1].toTensor();
         #endif
         // timestamp
-        if(outputs.size() == 4){
-            torch::Tensor us_alphas_tensor;
-            torch::Tensor us_peaks_tensor;
-            #ifdef USE_GPU
-            us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
-            us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
-            #else
-            us_alphas_tensor = outputs[2].toTensor();
-            us_peaks_tensor = outputs[3].toTensor();
-            #endif
+        for(int index=0; index<batch_in; index++){
+            string result="";
+            if(outputs.size() == 4){
+                torch::Tensor us_alphas_tensor;
+                torch::Tensor us_peaks_tensor;
+                #ifdef USE_GPU
+                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+                #else
+                us_alphas_tensor = outputs[2].toTensor();
+                us_peaks_tensor = outputs[3].toTensor();
+                #endif
 
-            int us_alphas_shape_1 = us_alphas_tensor.size(1);
-            float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
-            std::vector<float> us_alphas(us_alphas_shape_1);
-            for (int i = 0; i < us_alphas_shape_1; i++) {
-                us_alphas[i] = us_alphas_data[i];
-            }
-
-            int us_peaks_shape_1 = us_peaks_tensor.size(1);
-            float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
-            std::vector<float> us_peaks(us_peaks_shape_1);
-            for (int i = 0; i < us_peaks_shape_1; i++) {
-                us_peaks[i] = us_peaks_data[i];
-            }
-			if (lm_ == nullptr) {
-                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
-			} else {
-			    result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-                if (input_finished) {
-                    result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
+                std::vector<float> us_alphas(paraformer_length[index]);
+                for (int i = 0; i < us_alphas.size(); i++) {
+                    us_alphas[i] = us_alphas_data[i];
                 }
-			}
-        }else{
-            if (lm_ == nullptr) {
-                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-            } else {
-                result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-                if (input_finished) {
-                    result = FinalizeDecode(wfst_decoder);
+
+                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
+                std::vector<float> us_peaks(paraformer_length[index]);
+                for (int i = 0; i < us_peaks.size(); i++) {
+                    us_peaks[i] = us_peaks_data[i];
+                }
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                    }
+                }
+            }else{
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder);
+                    }
                 }
             }
+            results.push_back(result);
         }
     }
     catch (std::exception const &e)
@@ -362,7 +375,7 @@
         LOG(ERROR)<<e.what();
     }
 
-    return result;
+    return results;
 }
 
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
index f099fbc..60f1582 100644
--- a/runtime/onnxruntime/src/paraformer-torch.h
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -48,13 +48,15 @@
         std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
         void Reset();
         void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
-        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
+        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
         string GreedySearch( float* in, int n_len, int64_t token_nums,
                              bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
 
         string Rescoring();
         string GetLang(){return language;};
         int GetAsrSampleRate() { return asr_sample_rate; };
+        void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
+        int GetBatchSize() {return batch_size_;};
         void StartUtterance();
         void EndUtterance();
         void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -88,6 +90,7 @@
         float cif_threshold = 1.0;
         float tail_alphas = 0.45;
         int asr_sample_rate = MODEL_SAMPLE_RATE;
+        int batch_size_ = 1;
     };
 
 } // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index c56421c..e690458 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -462,15 +462,23 @@
     asr_feats = out_feats;
 }
 
-string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
 {
+    std::vector<std::string> results;
+    string result="";
     WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
     int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
 
+    if(batch_in != 1){
+        results.push_back(result);
+        return results;
+    }
+
     std::vector<std::vector<float>> asr_feats;
-    FbankKaldi(asr_sample_rate, din, len, asr_feats);
+    FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
     if(asr_feats.size() == 0){
-      return "";
+        results.push_back(result);
+        return results;
     }
     LfrCmvn(asr_feats);
     int32_t feat_dim = lfr_m*in_feat_dim;
@@ -509,7 +517,8 @@
         if (use_hotword) {
             if(hw_emb.size()<=0){
                 LOG(ERROR) << "hw_emb is null";
-                return "";
+                results.push_back(result);
+                return results;
             }
             //PrintMat(hw_emb, "input_clas_emb");
             const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@@ -526,10 +535,10 @@
     }catch (std::exception const &e)
     {
         LOG(ERROR)<<e.what();
-        return "";
+        results.push_back(result);
+        return results;
     }
 
-    string result="";
     try {
         auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -577,7 +586,8 @@
         LOG(ERROR)<<e.what();
     }
 
-    return result;
+    results.push_back(result);
+    return results;
 }
 
 
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 5bb9477..aa683bb 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -52,13 +52,14 @@
         std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
         void Reset();
         void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
-        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
+        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
         string GreedySearch( float* in, int n_len, int64_t token_nums,
                              bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
 
         string Rescoring();
         string GetLang(){return language;};
         int GetAsrSampleRate() { return asr_sample_rate; };
+        int GetBatchSize() {return batch_size_;};
         void StartUtterance();
         void EndUtterance();
         void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -110,6 +111,7 @@
         float cif_threshold = 1.0;
         float tail_alphas = 0.45;
         int asr_sample_rate = MODEL_SAMPLE_RATE;
+        int batch_size_ = 1;
     };
 
 } // namespace funasr

--
Gitblit v1.9.1