雾聪
2024-10-29 1819303f5e8cfc03f4c0ec2495571a54a186d34b
runtime/onnxruntime/src/sensevoice-small.h
@@ -12,12 +12,14 @@
    class SenseVoiceSmall : public Model {
    private:
        Vocab* vocab = nullptr;
        Vocab* online_vocab = nullptr;
        Vocab* lm_vocab = nullptr;
        SegDict* seg_dict = nullptr;
        PhoneSet* phone_set_ = nullptr;
        const float scale = 1.0;
        void LoadConfigFromYaml(const char* filename);
        void LoadOnlineConfigFromYaml(const char* filename);
        void LoadCmvn(const char *filename);
        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
@@ -34,9 +36,10 @@
        ~SenseVoiceSmall();
        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        // online
        // void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        // 2pass
        // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config,
            const std::string &token_file, const std::string &online_token_file, int thread_num);
        // void InitHwCompiler(const std::string &hw_model, int thread_num);
        // void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
@@ -44,7 +47,8 @@
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, std::string svs_lang="auto", bool svs_itn=true, int batch_in=1);
        string CTCSearch( float * in, std::vector<int32_t> paraformer_length, std::vector<int64_t> outputShape);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
@@ -100,6 +104,8 @@
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;
        int blank_id = 0;
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        //dict
        std::map<std::string, int> lid_map = {
            {"auto", 0},