lyblsgo
2023-04-24 b6d0ab4bfba04037203b3b9f6a34951e1525f36a
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
    }
}
string Paraformer::GreedySearch(float * in, int n_len )
string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * 8404, 8404, max_val, max_idx);
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
@@ -238,7 +238,7 @@
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
        result = GreedySearch(floatData, *encoder_out_lens);
        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
    }
    catch (std::exception const &e)
    {