From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Fri, 29 Mar 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward
---
runtime/onnxruntime/src/paraformer-torch.cpp | 143 ++++++++++++++++++++++++++++-------------------
1 file changed, 84 insertions(+), 59 deletions(-)
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index bdfd0ee..e603e89 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -45,8 +45,15 @@
#ifdef USE_IPEX
torch::jit::setTensorExprFuserEnabled(false);
#endif
- torch::jit::script::Module model = torch::jit::load(am_model, device);
- model_ = std::make_shared<TorchModule>(std::move(model));
+
+ try {
+ torch::jit::script::Module model = torch::jit::load(am_model, device);
+ model_ = std::make_shared<TorchModule>(std::move(model));
+ LOG(INFO) << "Successfully load model from " << am_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am model: " << am_model << e.what();
+ exit(-1);
+ }
}
void ParaformerTorch::InitLm(const std::string &lm_file,
@@ -258,34 +265,50 @@
asr_feats = out_feats;
}
-string ParaformerTorch::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+ int32_t feature_dim = lfr_m*in_feat_dim;
- std::vector<std::vector<float>> asr_feats;
- FbankKaldi(asr_sample_rate, din, len, asr_feats);
- if(asr_feats.size() == 0){
- return "";
- }
- LfrCmvn(asr_feats);
- int32_t feat_dim = lfr_m*in_feat_dim;
- int32_t num_frames = asr_feats.size();
-
- std::vector<float> wav_feats;
- for (const auto &frame_feat: asr_feats) {
- wav_feats.insert(wav_feats.end(), frame_feat.begin(), frame_feat.end());
- }
+ std::vector<vector<float>> feats_batch;
std::vector<int32_t> paraformer_length;
- paraformer_length.emplace_back(num_frames);
+ int max_size = 0;
+ int max_frames = 0;
+ for(int index=0; index<batch_in; index++){
+ std::vector<std::vector<float>> asr_feats;
+ FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
+ if(asr_feats.size() != 0){
+ LfrCmvn(asr_feats);
+ }
+ int32_t num_frames = asr_feats.size();
+ paraformer_length.emplace_back(num_frames);
+ if(max_frames < num_frames){
+ max_size = num_frames * feature_dim;
+ max_frames = num_frames;
+ }
+
+ std::vector<float> flattened;
+ for (const auto& sub_vector : asr_feats) {
+ flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
+ }
+ feats_batch.emplace_back(flattened);
+ }
torch::NoGradGuard no_grad;
model_->eval();
+ // padding
+ std::vector<float> all_feats(batch_in * max_frames * feature_dim);
+ for(int index=0; index<batch_in; index++){
+ feats_batch[index].resize(max_size);
+ std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
+ max_frames * feature_dim * sizeof(float));
+ }
torch::Tensor feats =
- torch::from_blob(wav_feats.data(),
- {1, num_frames, feat_dim}, torch::kFloat).contiguous();
+ torch::from_blob(all_feats.data(),
+ {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
- {1}, torch::kInt32);
+ {batch_in}, torch::kInt32);
// 2. forward
#ifdef USE_GPU
@@ -294,7 +317,7 @@
#endif
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
- string result="";
+ vector<std::string> results;
try {
auto outputs = model_->forward(inputs).toTuple()->elements();
torch::Tensor am_scores;
@@ -307,47 +330,49 @@
valid_token_lens = outputs[1].toTensor();
#endif
// timestamp
- if(outputs.size() == 4){
- torch::Tensor us_alphas_tensor;
- torch::Tensor us_peaks_tensor;
- #ifdef USE_GPU
- us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
- us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
- #else
- us_alphas_tensor = outputs[2].toTensor();
- us_peaks_tensor = outputs[3].toTensor();
- #endif
+ for(int index=0; index<batch_in; index++){
+ string result="";
+ if(outputs.size() == 4){
+ torch::Tensor us_alphas_tensor;
+ torch::Tensor us_peaks_tensor;
+ #ifdef USE_GPU
+ us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+ us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+ #else
+ us_alphas_tensor = outputs[2].toTensor();
+ us_peaks_tensor = outputs[3].toTensor();
+ #endif
- int us_alphas_shape_1 = us_alphas_tensor.size(1);
- float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
- std::vector<float> us_alphas(us_alphas_shape_1);
- for (int i = 0; i < us_alphas_shape_1; i++) {
- us_alphas[i] = us_alphas_data[i];
- }
-
- int us_peaks_shape_1 = us_peaks_tensor.size(1);
- float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
- std::vector<float> us_peaks(us_peaks_shape_1);
- for (int i = 0; i < us_peaks_shape_1; i++) {
- us_peaks[i] = us_peaks_data[i];
- }
- if (lm_ == nullptr) {
- result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
- } else {
- result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
- if (input_finished) {
- result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+ float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
+ std::vector<float> us_alphas(paraformer_length[index]);
+ for (int i = 0; i < us_alphas.size(); i++) {
+ us_alphas[i] = us_alphas_data[i];
}
- }
- }else{
- if (lm_ == nullptr) {
- result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
- } else {
- result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
- if (input_finished) {
- result = FinalizeDecode(wfst_decoder);
+
+ float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
+ std::vector<float> us_peaks(paraformer_length[index]);
+ for (int i = 0; i < us_peaks.size(); i++) {
+ us_peaks[i] = us_peaks_data[i];
+ }
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+ }
+ }
+ }else{
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder);
+ }
}
}
+ results.push_back(result);
}
}
catch (std::exception const &e)
@@ -355,7 +380,7 @@
LOG(ERROR)<<e.what();
}
- return result;
+ return results;
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
--
Gitblit v1.9.1