From 38c1f6393a16f4d15a8897647b8f8693a48f737d Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 26 Jun 2024 11:39:19 +0800
Subject: [PATCH] add warmup for paraformer-torch

---
 runtime/onnxruntime/src/paraformer-torch.cpp |   48 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+), 0 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index 466d80a..4e550b9 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -55,6 +55,9 @@
         torch::jit::setGraphExecutorOptimize(false);
         torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
         torch::jit::setFusionStrategy(static0);
+        #ifdef USE_GPU
+        WarmUp();
+        #endif
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load am model: " << am_model << e.what();
         exit(-1);
@@ -471,6 +474,51 @@
     return results;
 }
 
+void ParaformerTorch::WarmUp()
+{
+    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+    int32_t feature_dim = lfr_m*in_feat_dim;
+    int batch_in = 1;
+    int max_frames = 10;
+    std::vector<int32_t> paraformer_length;
+    paraformer_length.push_back(max_frames);
+
+    std::vector<float> all_feats(batch_in * max_frames * feature_dim, 0.1);
+    torch::Tensor feats =
+        torch::from_blob(all_feats.data(),
+                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
+    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
+                        {batch_in}, torch::kInt32);
+
+    // 2. forward
+    feats = feats.to(at::kCUDA);
+    feat_lens = feat_lens.to(at::kCUDA);
+    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
+
+    if (use_hotword) {
+        std::string hotwords_wp = "";
+        std::vector<std::vector<float>> hw_emb =  CompileHotwordEmbedding(hotwords_wp);
+        std::vector<float> embedding;
+        embedding.reserve(hw_emb.size() * hw_emb[0].size());
+        for (auto item : hw_emb) {
+            embedding.insert(embedding.end(), item.begin(), item.end());
+        }
+        torch::Tensor tensor_hw_emb =
+            torch::from_blob(embedding.data(),
+                    {batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous();
+        tensor_hw_emb = tensor_hw_emb.to(at::kCUDA);
+        inputs.emplace_back(tensor_hw_emb);
+    }
+
+    try {
+        auto outputs = model_->forward(inputs).toTuple()->elements();
+    }
+    catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+    }
+}
+
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
     int embedding_dim = encoder_size;
     std::vector<std::vector<float>> hw_emb;

--
Gitblit v1.9.1