游雁
2023-10-19 b9bcf1f093c3053fdc4e2cf4a1d38e27bbf429fb
funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -101,35 +101,39 @@
        waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
        }
        if (lfr_splice_cache_.empty()) {
        for (int i = 0; i < (lfr_m - 1) / 2; i++) {
            lfr_splice_cache_.emplace_back(wav_feats[0]);
        }
            for (int i = 0; i < (lfr_m - 1) / 2; i++) {
                lfr_splice_cache_.emplace_back(wav_feats[0]);
            }
        }
        if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
        wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
        int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
        int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
        int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
        int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
        reserve_waveforms_.clear();
        reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                                    waves.begin() + frame_from_waves * frame_shift_sample_length_);
        int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
        waves.erase(waves.begin() + sample_length, waves.end());
            wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
            int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
            int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
            int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
            int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
            reserve_waveforms_.clear();
            reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                        waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                                        waves.begin() + frame_from_waves * frame_shift_sample_length_);
            int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
            waves.erase(waves.begin() + sample_length, waves.end());
        } else {
        reserve_waveforms_.clear();
        reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
        lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
            reserve_waveforms_.clear();
            reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                        waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
            lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
        }
    } else {
        if (input_finished) {
        if (!reserve_waveforms_.empty()) {
            waves = reserve_waveforms_;
        }
        wav_feats = lfr_splice_cache_;
        OnlineLfrCmvn(wav_feats, input_finished);
            if (!reserve_waveforms_.empty()) {
                waves = reserve_waveforms_;
            }
            wav_feats = lfr_splice_cache_;
            if(wav_feats.size() == 0){
                LOG(ERROR) << "wav_feats's size is 0";
            }else{
                OnlineLfrCmvn(wav_feats, input_finished);
            }
        }
    }
    if(input_finished){
@@ -465,7 +469,7 @@
    return result;
}
string ParaformerOnline::Forward(float* din, int len, bool input_finished)
string ParaformerOnline::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb)
{
    std::vector<std::vector<float>> wav_feats;
    std::vector<float> waves(din, din+len);