游雁
2023-10-10 f974935484d5d8eb37b36eb2646816c02a41184c
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
11个文件已修改
78 ■■■■ 已修改文件
funasr/runtime/onnxruntime/include/model.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/punc-model.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer-online.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer.cpp 24 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasrruntime.cpp 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.h 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vocab.cpp 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vocab.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/bin/funasr-wss-server.cpp 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/model.h
@@ -18,6 +18,7 @@
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
funasr/runtime/onnxruntime/include/punc-model.h
@@ -12,8 +12,8 @@
  public:
    virtual ~PuncModel(){};
      virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
      virtual std::string AddPunc(const char* sz_input){return "";};
      virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){return "";};
      virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
      virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache, std::string language="zh-cn"){return "";};
};
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -50,7 +50,7 @@
{
}
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache)
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language)
{
    string strResult;
    vector<string> strOut;
funasr/runtime/onnxruntime/src/ct-transformer-online.h
@@ -29,7 +29,7 @@
    void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
    ~CTTransformerOnline();
    vector<int>  Infer(vector<int32_t> input_data, int nCacheSize);
    string AddPunc(const char* sz_input, vector<string> &arr_cache);
    string AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language="zh-cn");
    void Transport(vector<float>& In, int nRows, int nCols);
    void VadMask(int size, int vad_pos,vector<float>& Result);
    void Triangle(int text_length, vector<float>& Result);
funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -46,7 +46,7 @@
{
}
string CTTransformer::AddPunc(const char* sz_input)
string CTTransformer::AddPunc(const char* sz_input, std::string language)
{
    string strResult;
    vector<string> strOut;
@@ -139,8 +139,28 @@
            }
        }
    }
    for (auto& item : NewSentenceOut)
    for (auto& item : NewSentenceOut){
        strResult += item;
    }
    if(language == "en-bpe"){
        std::vector<std::string> chineseSymbols;
        chineseSymbols.push_back(",");
        chineseSymbols.push_back("。");
        chineseSymbols.push_back("、");
        chineseSymbols.push_back("?");
        std::string englishSymbols = ",.,?";
        for (size_t i = 0; i < chineseSymbols.size(); i++) {
            size_t pos = 0;
            while ((pos = strResult.find(chineseSymbols[i], pos)) != std::string::npos) {
                strResult.replace(pos, 3, 1, englishSymbols[i]);
                pos++;
            }
        }
    }
    return strResult;
}
funasr/runtime/onnxruntime/src/ct-transformer.h
@@ -29,6 +29,6 @@
    void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
    ~CTTransformer();
    vector<int>  Infer(vector<int32_t> input_data);
    string AddPunc(const char* sz_input);
    string AddPunc(const char* sz_input, std::string language="zh-cn");
};
} // namespace funasr
funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -282,7 +282,8 @@
            p_result->stamp += cur_stamp + "]";
        }
        if(offline_stream->UsePunc()){
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
            string lang = (offline_stream->asr_handle)->GetLang();
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
            p_result->msg = punc_res;
        }
#if !defined(__APPLE__)
@@ -363,7 +364,8 @@
            p_result->stamp += cur_stamp + "]";
        }
        if(offline_stream->UsePunc()){
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
            string lang = (offline_stream->asr_handle)->GetLang();
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str(), lang);
            p_result->msg = punc_res;
        }
#if !defined(__APPLE__)
funasr/runtime/onnxruntime/src/paraformer.h
@@ -33,7 +33,6 @@
        vector<const char*> hw_m_szInputNames;
        vector<const char*> hw_m_szOutputNames;
        bool use_hotword;
        std::string language="zh-cn";
    public:
        Paraformer();
@@ -55,6 +54,7 @@
        string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list);
        string Rescoring();
        string GetLang(){return language;};
        knf::FbankOptions fbank_opts_;
        vector<float> means_list_;
@@ -71,6 +71,8 @@
        vector<const char*> m_szInputNames;
        vector<const char*> m_szOutputNames;
        std::string language="zh-cn";
        // paraformer-online
        std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
        std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
funasr/runtime/onnxruntime/src/vocab.cpp
@@ -75,6 +75,21 @@
    return false;
}
string Vocab::WordFormat(std::string word)
{
    if(word == "i"){
        return "I";
    }else if(word == "i'm"){
        return "I'm";
    }else if(word == "i've"){
        return "I've";
    }else if(word == "i'll"){
        return "I'll";
    }else{
        return word;
    }
}
string Vocab::Vector2StringV2(vector<int> in, std::string language)
{
    int i;
@@ -94,6 +109,7 @@
            size_t found = word.find(unicodeChar);
            if(found != std::string::npos){
                if (combine != ""){
                    combine = WordFormat(combine);
                    if (words.size() != 0){
                        combine = " " + combine;
                    }
@@ -164,6 +180,7 @@
    }
    if (language == "en-bpe" and combine != ""){
        combine = WordFormat(combine);
        if (words.size() != 0){
            combine = " " + combine;
        }
funasr/runtime/onnxruntime/src/vocab.h
@@ -23,6 +23,7 @@
    bool IsChinese(string ch);
    void Vector2String(vector<int> in, std::vector<std::string> &preds);
    string Vector2StringV2(vector<int> in, std::string language="");
    string WordFormat(std::string word);
    int GetIdByToken(const std::string &token);
};
funasr/runtime/websocket/bin/funasr-wss-server.cpp
@@ -195,11 +195,16 @@
                size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
                if (found != std::string::npos) {
                    model_path["model-revision"]="v1.2.4";
                }else{
                    found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
                    if (found != std::string::npos) {
                        model_path["model-revision"]="v1.0.5";
                    }
                }
                found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
                if (found != std::string::npos) {
                    model_path["model-revision"]="v1.0.5";
                }
                found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
                if (found != std::string::npos) {
                    model_path["model-revision"]="v1.0.0";
                }
                // modelscope