From a58c3d4593892459614aabda689fb61af74e20fe Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 16:47:27 +0800
Subject: [PATCH] add batch for paraformer

---
 runtime/onnxruntime/src/paraformer-torch.cpp |  127 +++++++++++++++++++++++-------------------
 1 files changed, 70 insertions(+), 57 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index 06c88f6..e7fbadf 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -265,34 +265,45 @@
     asr_feats = out_feats;
 }
 
-string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
 {
     WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
     int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+    int32_t feature_dim = lfr_m*in_feat_dim;
 
-    std::vector<std::vector<float>> asr_feats;
-    FbankKaldi(asr_sample_rate, din, len, asr_feats);
-    if(asr_feats.size() == 0){
-      return "";
-    }
-    LfrCmvn(asr_feats);
-    int32_t feat_dim = lfr_m*in_feat_dim;
-    int32_t num_frames = asr_feats.size();
-
-    std::vector<float> wav_feats;
-    for (const auto &frame_feat: asr_feats) {
-        wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
-    }
+    std::vector<vector<float>> feats_batch;
     std::vector<int32_t> paraformer_length;
-    paraformer_length.emplace_back(num_frames);
+    int max_size = 0;
+    int max_frames = 0;
+    for(int index=0; index<batch_in; index++){
+        std::vector<std::vector<float>> asr_feats;
+        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
+        if(asr_feats.size() != 0){
+            LfrCmvn(asr_feats);
+        }
+        feats_batch.emplace_back(asr_feats);
+        int32_t num_frames  = asr_feats.size() / feature_dim;
+        paraformer_length.emplace_back(num_frames);
+        if(max_size < asr_feats.size()){
+            max_size = asr_feats.size();
+            max_frames = num_frames;
+        }
+    }
 
     torch::NoGradGuard no_grad;
     model_->eval();
+    // padding
+    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
+    for(int index=0; index<batch_in; index++){
+        feats_batch[index].resize(max_size);
+        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
+                        max_frames * feature_dim * sizeof(float));
+    }
     torch::Tensor feats =
-        torch::from_blob(wav_feats.data(),
-                {1, num_frames, feat_dim}, torch::kFloat).contiguous();
+        torch::from_blob(all_feats.data(),
+                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
     torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
-                        {1}, torch::kInt32);
+                        {batch_in}, torch::kInt32);
 
     // 2. forward
     #ifdef USE_GPU
@@ -301,7 +312,7 @@
     #endif
     std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
 
-    string result="";
+    vector<std::string> results;
     try {
         auto outputs = model_->forward(inputs).toTuple()->elements();
         torch::Tensor am_scores;
@@ -314,47 +325,49 @@
         valid_token_lens = outputs[1].toTensor();
         #endif
         // timestamp
-        if(outputs.size() == 4){
-            torch::Tensor us_alphas_tensor;
-            torch::Tensor us_peaks_tensor;
-            #ifdef USE_GPU
-            us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
-            us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
-            #else
-            us_alphas_tensor = outputs[2].toTensor();
-            us_peaks_tensor = outputs[3].toTensor();
-            #endif
+        for(int index=0; index<batch_in; index++){
+            string result="";
+            if(outputs.size() == 4){
+                torch::Tensor us_alphas_tensor;
+                torch::Tensor us_peaks_tensor;
+                #ifdef USE_GPU
+                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+                #else
+                us_alphas_tensor = outputs[2].toTensor();
+                us_peaks_tensor = outputs[3].toTensor();
+                #endif
 
-            int us_alphas_shape_1 = us_alphas_tensor.size(1);
-            float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
-            std::vector<float> us_alphas(us_alphas_shape_1);
-            for (int i = 0; i < us_alphas_shape_1; i++) {
-                us_alphas[i] = us_alphas_data[i];
-            }
-
-            int us_peaks_shape_1 = us_peaks_tensor.size(1);
-            float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
-            std::vector<float> us_peaks(us_peaks_shape_1);
-            for (int i = 0; i < us_peaks_shape_1; i++) {
-                us_peaks[i] = us_peaks_data[i];
-            }
-			if (lm_ == nullptr) {
-                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
-			} else {
-			    result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-                if (input_finished) {
-                    result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
+                std::vector<float> us_alphas(paraformer_length[index]);
+                for (int i = 0; i < us_alphas.size(); i++) {
+                    us_alphas[i] = us_alphas_data[i];
                 }
-			}
-        }else{
-            if (lm_ == nullptr) {
-                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-            } else {
-                result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-                if (input_finished) {
-                    result = FinalizeDecode(wfst_decoder);
+
+                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
+                std::vector<float> us_peaks(paraformer_length[index]);
+                for (int i = 0; i < us_peaks.size(); i++) {
+                    us_peaks[i] = us_peaks_data[i];
+                }
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                    }
+                }
+            }else{
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder);
+                    }
                 }
             }
+            results.push_back(result);
         }
     }
     catch (std::exception const &e)
@@ -362,7 +375,7 @@
         LOG(ERROR)<<e.what();
     }
 
-    return result;
+    return results;
 }
 
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {

--
Gitblit v1.9.1