aky15
2023-04-10 d46a542fae26009eee16204a81903862cb4dba73
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,14 +3,25 @@
using namespace std;
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread)
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
    string model_path = pathAppend(path, "model.onnx");
    string vocab_path = pathAppend(path, "vocab.txt");
    string model_path;
    string cmvn_path;
    string config_path;
    if(quantize)
    {
        model_path = pathAppend(path, "model_quant.onnx");
    }else{
        model_path = pathAppend(path, "model.onnx");
    }
    cmvn_path = pathAppend(path, "am.mvn");
    config_path = pathAppend(path, "config.yaml");
    fe = new FeatureExtract(3);
    sessionOptions.SetInterOpNumThreads(nNumThread);
    //sessionOptions.SetInterOpNumThreads(1);
    sessionOptions.SetIntraOpNumThreads(nNumThread);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
#ifdef _WIN32
@@ -35,7 +46,8 @@
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    vocab = new Vocab(vocab_path.c_str());
    vocab = new Vocab(config_path.c_str());
    load_cmvn(cmvn_path.c_str());
}
ModelImp::~ModelImp()
@@ -80,16 +92,49 @@
    din = tmp;
}
void ModelImp::load_cmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    vars_list.push_back(stof(vars_lines[j])*scale);
                }
                continue;
            }
        }
    }
}
void ModelImp::apply_cmvn(Tensor<float>* din)
{
    const float* var;
    const float* mean;
    float scale = 22.6274169979695;
    var = vars_list.data();
    mean= means_list.data();
    int m = din->size[2];
    int n = din->size[3];
    var = (const float*)paraformer_cmvn_var_hex;
    mean = (const float*)paraformer_cmvn_mean_hex;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            int idx = i * n + j;