From 16a3cd3cfb866e99b68417e36a7c8ae613678fbf Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 27 Mar 2024 14:11:27 +0800
Subject: [PATCH] update warmup to 10
---
runtime/onnxruntime/src/paraformer-torch.cpp | 67 +++++++++++++++++++++++++++------
 1 file changed, 55 insertions(+), 12 deletions(-)
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index 1f15ec7..06c88f6 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -38,15 +38,22 @@
LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
exit(-1);
} else {
- LOG(INFO) << "CUDA available! Running on GPU";
+ LOG(INFO) << "CUDA is available, running on GPU";
device = at::kCUDA;
}
#endif
#ifdef USE_IPEX
torch::jit::setTensorExprFuserEnabled(false);
#endif
- torch::jit::script::Module model = torch::jit::load(am_model, device);
- model_ = std::make_shared<TorchModule>(std::move(model));
+
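+  // Load the TorchScript model onto the selected device; exit fast if loading fails.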
+ try {
+ torch::jit::script::Module model = torch::jit::load(am_model, device);
+ model_ = std::make_shared<TorchModule>(std::move(model));
+ LOG(INFO) << "Successfully load model from " << am_model;
+ } catch (std::exception const &e) {
+    LOG(ERROR) << "Error loading am model " << am_model << ": " << e.what();
+ exit(-1);
+ }
}
void ParaformerTorch::InitLm(const std::string &lm_file,
@@ -280,6 +287,7 @@
paraformer_length.emplace_back(num_frames);
torch::NoGradGuard no_grad;
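+  // Run the module in eval mode so dropout/batch-norm layers behave deterministically.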
+ model_->eval();
torch::Tensor feats =
torch::from_blob(wav_feats.data(),
{1, num_frames, feat_dim}, torch::kFloat).contiguous();
@@ -305,15 +313,49 @@
am_scores = outputs[0].toTensor();
valid_token_lens = outputs[1].toTensor();
#endif
-
- if (lm_ == nullptr) {
- result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
- } else {
- result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
- if (input_finished) {
- result = FinalizeDecode(wfst_decoder);
+    // Timestamp mode: the model also returns us_alphas and us_peaks tensors.
+    if (outputs.size() == 4) {
+ torch::Tensor us_alphas_tensor;
+ torch::Tensor us_peaks_tensor;
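+      // On GPU builds, copy the timestamp tensors back to CPU before reading them.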
+ #ifdef USE_GPU
+ us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+ us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+ #else
+ us_alphas_tensor = outputs[2].toTensor();
+ us_peaks_tensor = outputs[3].toTensor();
+ #endif
+
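+      // Flatten the per-frame alpha weights into a std::vector.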
+ int us_alphas_shape_1 = us_alphas_tensor.size(1);
+ float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
+ std::vector<float> us_alphas(us_alphas_shape_1);
+ for (int i = 0; i < us_alphas_shape_1; i++) {
+ us_alphas[i] = us_alphas_data[i];
}
- }
+
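+      // Likewise flatten the per-frame peak values used for token timestamps.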
+ int us_peaks_shape_1 = us_peaks_tensor.size(1);
+ float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
+ std::vector<float> us_peaks(us_peaks_shape_1);
+ for (int i = 0; i < us_peaks_shape_1; i++) {
+ us_peaks[i] = us_peaks_data[i];
+ }
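+      // Decode with the timestamp vectors attached to the result.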
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+ }
+ }
+    } else {
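+      // No timestamp outputs: fall back to plain decoding.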
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder);
+ }
+ }
+ }
}
catch (std::exception const &e)
{
@@ -324,7 +366,8 @@
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
- std::vector<std::vector<float>> result;
+  // TODO: compute real hotword embeddings; return a 1x10 zero placeholder for now.
+  std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
return result;
}
--
Gitblit v1.9.1