lyblsgo
2023-04-24 b6d0ab4bfba04037203b3b9f6a34951e1525f36a
fix GreedySearch
9个文件已修改
35 ■■■■■ 已修改文件
funasr/runtime/onnxruntime/include/com-define.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/libfunasrapi.h 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/alignedmem.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/commonfunc.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/model.cpp 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/online-feature.h 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.cpp 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.h 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/com-define.h
@@ -28,7 +28,6 @@
// punc
#define PUNC_MODEL_FILE  "punc_model.onnx"
#define PUNC_YAML_FILE "punc.yaml"
#define UNK_CHAR "<unk>"
#define  INPUT_NUM  2
funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -51,21 +51,14 @@
// if not give a fn_callback ,it should be NULL 
_FUNASRAPI FUNASR_RESULT    FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI const char*    FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int        FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI const int    FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void            FunASRFreeResult(FUNASR_RESULT result);
_FUNASRAPI void            FunASRUninit(FUNASR_HANDLE handle);
_FUNASRAPI const float    FunASRGetRetSnippetTime(FUNASR_RESULT result);
#ifdef __cplusplus 
funasr/runtime/onnxruntime/src/alignedmem.h
@@ -2,8 +2,6 @@
#ifndef ALIGNEDMEM_H
#define ALIGNEDMEM_H
extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
extern void AlignedFree(void *p);
funasr/runtime/onnxruntime/src/commonfunc.h
@@ -33,7 +33,6 @@
        {
            auto t = session->GetInputNameAllocated(nIndex, allocator);
            inputName = t.get();
        }
    }
}
@@ -45,7 +44,6 @@
        {
            auto t = session->GetOutputNameAllocated(nIndex, allocator);
            outputName = t.get();
        }
    }
}
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -58,7 +58,6 @@
        }else{
            cout <<"No return data!";
        }
    }
    {
        lock_guard<mutex> guard(mtx);
funasr/runtime/onnxruntime/src/model.cpp
@@ -3,8 +3,6 @@
Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
{
    Model *mm;
    mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
    return mm;
}
funasr/runtime/onnxruntime/src/online-feature.h
@@ -12,15 +12,12 @@
  void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
private:
  void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
  int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
  static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
    int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
    if (frame_num >= 1 && sample_length >= frame_sample_length)
      return frame_num;
    else
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
    }
}
string Paraformer::GreedySearch(float * in, int n_len )
string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * 8404, 8404, max_val, max_idx);
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
@@ -238,7 +238,7 @@
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
        result = GreedySearch(floatData, *encoder_out_lens);
        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
    }
    catch (std::exception const &e)
    {
funasr/runtime/onnxruntime/src/paraformer.h
@@ -9,6 +9,11 @@
namespace paraformer {
    class Paraformer : public Model {
    /**
     * Author: Speech Lab of DAMO Academy, Alibaba Group
     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     * https://arxiv.org/pdf/2206.08317.pdf
    */
    private:
        //std::unique_ptr<knf::OnlineFbank> fbank_;
        knf::FbankOptions fbank_opts;
@@ -27,7 +32,7 @@
        vector<float> ApplyLfr(const vector<float> &in);
        void ApplyCmvn(vector<float> *v);
        string GreedySearch( float* in, int n_len);
        string GreedySearch( float* in, int n_len, int64_t token_nums);
        std::shared_ptr<Ort::Session> m_session;
        Ort::Env env_;