From ade08818b7a579aac75182b906a5bd3b8126411c Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Mon, 27 May 2024 15:46:26 +0800
Subject: [PATCH] Merge branch 'dev_batch' into main
---
runtime/onnxruntime/src/paraformer-torch.h | 96 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+), 0 deletions(-)
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
new file mode 100644
index 0000000..e49094d
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+#pragma once
+#define C10_USE_GLOG
+#include <torch/serialize.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
+#include "precomp.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
+#include "bias-lm.h"
+#include "phone-set.h"
+
+namespace funasr {
+
+    class ParaformerTorch : public Model {
+    /**
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+     * https://arxiv.org/pdf/2206.08317.pdf
+     */
+    private:
+        Vocab* vocab = nullptr;          // ASR output vocabulary (ownership not shown here; presumably released in ~ParaformerTorch -- confirm in .cpp)
+        Vocab* lm_vocab = nullptr;       // vocabulary used by the external LM path (see InitLm/GetLmVocab)
+        SegDict* seg_dict = nullptr;     // segmentation dictionary, set up via InitSegDict
+        PhoneSet* phone_set_ = nullptr;  // phone set exposed through GetPhoneSet for WFST decoding
+        //const float scale = 22.6274169979695;
+        const float scale = 1.0;         // feature scale factor; previous value kept above for reference
+
+        void LoadConfigFromYaml(const char* filename);                // parse model config (called during init; see .cpp)
+        void LoadCmvn(const char *filename);                          // presumably fills means_list_/vars_list_ -- confirm in .cpp
+        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);     // apply LFR stacking + CMVN to features in place
+
+        using TorchModule = torch::jit::script::Module;
+        std::shared_ptr<TorchModule> model_ = nullptr;   // TorchScript acoustic model
+        std::vector<torch::Tensor> encoder_outs_;        // cached encoder outputs between calls
+        bool use_hotword = false;  // FIX: was uninitialized (UB if read before ctor sets it); default matches "no hotword compiler loaded"
+
+    public:
+        ParaformerTorch();
+        ~ParaformerTorch();
+        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitHwCompiler(const std::string &hw_model, int thread_num);      // load hotword-embedding compiler model
+        void InitSegDict(const std::string &seg_dict_model);
+        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
+        void Reset();
+        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
+        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);  // main batch recognition entry point
+        string GreedySearch( float* in, int n_len, int64_t token_nums,
+            bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+
+        string Rescoring();
+        string GetLang(){return language;};
+        int GetAsrSampleRate() { return asr_sample_rate; };
+        void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
+        int GetBatchSize() {return batch_size_;};
+        void StartUtterance();
+        void EndUtterance();
+        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
+        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
+        string FinalizeDecode(WfstDecoder* &wfst_decoder,
+            bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+        Vocab* GetVocab();
+        Vocab* GetLmVocab();
+        PhoneSet* GetPhoneSet();
+
+        knf::FbankOptions fbank_opts_;   // kaldi-native-fbank feature extraction options
+        vector<float> means_list_;       // CMVN means
+        vector<float> vars_list_;        // CMVN variances
+        int lfr_m = PARA_LFR_M;          // LFR window size (constant from precomp.h)
+        int lfr_n = PARA_LFR_N;          // LFR hop (constant from precomp.h)
+
+        // paraformer-offline
+        std::string language="zh-cn";
+
+        // lm: external n-gram LM as an OpenFst StdArc FST (set by InitLm)
+        std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
+
+        string window_type = "hamming";
+        int frame_length = 25;           // presumably milliseconds -- confirm against fbank_opts_ usage
+        int frame_shift = 10;            // presumably milliseconds -- confirm against fbank_opts_ usage
+        int n_mels = 80;
+        int encoder_size = 512;
+        int fsmn_layers = 16;
+        int fsmn_lorder = 10;
+        int fsmn_dims = 512;
+        float cif_threshold = 1.0;
+        float tail_alphas = 0.45;
+        int asr_sample_rate = MODEL_SAMPLE_RATE;   // constant from precomp.h
+        int batch_size_ = 1;
+    };
+
+} // namespace funasr
--
Gitblit v1.9.1