| funasr/runtime/onnxruntime/src/paraformer.cpp | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/onnxruntime/src/paraformer.h | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/onnxruntime/src/vocab.cpp | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/onnxruntime/src/vocab.h | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -65,6 +65,7 @@ for (auto& item : m_strOutputNames) m_szOutputNames.push_back(item.c_str()); vocab = new Vocab(am_config.c_str()); LoadConfigFromYaml(am_config.c_str()); LoadCmvn(am_cmvn.c_str()); } @@ -181,6 +182,27 @@ m_szInputNames.push_back(item.c_str()); for (auto& item : m_strOutputNames) m_szOutputNames.push_back(item.c_str()); } void Paraformer::LoadConfigFromYaml(const char* filename){ YAML::Node config; try{ config = YAML::LoadFile(filename); }catch(exception const &e){ LOG(ERROR) << "Error loading file, yaml file error or not exist."; exit(-1); } try{ YAML::Node lang_conf = config["lang"]; if (lang_conf.IsDefined()){ language = lang_conf.as<string>(); } }catch(exception const &e){ LOG(ERROR) << "Error when load argument from vad config YAML."; exit(-1); } } void Paraformer::LoadOnlineConfigFromYaml(const char* filename){ @@ -342,7 +364,7 @@ hyps.push_back(max_idx); } if(!is_stamp){ return vocab->Vector2StringV2(hyps); return vocab->Vector2StringV2(hyps, language); }else{ std::vector<string> char_list; std::vector<std::vector<float>> timestamp_list; @@ -707,17 +729,6 @@ }else{ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); } // int pos = 0; // std::vector<std::vector<float>> logits; // for (int j = 0; j < outputShape[1]; j++) // { // std::vector<float> vec_token; // vec_token.insert(vec_token.begin(), floatData + pos, floatData + pos + outputShape[2]); // logits.push_back(vec_token); // pos += outputShape[2]; // } // //PrintMat(logits, "logits_out"); // result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); } catch (std::exception const &e) { funasr/runtime/onnxruntime/src/paraformer.h
@@ -20,6 +20,7 @@ //const float scale = 22.6274169979695; const float scale = 1.0; void LoadConfigFromYaml(const char* filename); void LoadOnlineConfigFromYaml(const char* filename); void LoadCmvn(const char *filename); vector<float> ApplyLfr(const vector<float> &in); @@ -32,6 +33,7 @@ vector<const char*> hw_m_szInputNames; vector<const char*> hw_m_szOutputNames; bool use_hotword; std::string language="zh-cn"; public: Paraformer(); funasr/runtime/onnxruntime/src/vocab.cpp
@@ -75,20 +75,36 @@ return false; } string Vocab::Vector2StringV2(vector<int> in) string Vocab::Vector2StringV2(vector<int> in, std::string language) { int i; list<string> words; int is_pre_english = false; int pre_english_len = 0; int is_combining = false; string combine = ""; std::string combine = ""; std::string unicodeChar = "▁"; for (auto it = in.begin(); it != in.end(); it++) { string word = vocab[*it]; // step1 space character skips if (word == "<s>" || word == "</s>" || word == "<unk>") continue; if (language == "en-bpe"){ size_t found = word.find(unicodeChar); if(found != std::string::npos){ if (combine != ""){ if (words.size() != 0){ combine = " " + combine; } words.push_back(combine); } combine = word.substr(3); }else{ combine += word; } continue; } // step2 combie phoneme to full word { int sub_word = !(word.find("@@") == string::npos); funasr/runtime/onnxruntime/src/vocab.h
@@ -22,7 +22,7 @@ int Size(); bool IsChinese(string ch); void Vector2String(vector<int> in, std::vector<std::string> &preds); string Vector2StringV2(vector<int> in); string Vector2StringV2(vector<int> in, std::string language=""); int GetIdByToken(const std::string &token); };