雾聪
2024-03-27 16a3cd3cfb866e99b68417e36a7c8ae613678fbf
runtime/onnxruntime/src/paraformer-torch.cpp
@@ -38,15 +38,22 @@
        LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
        exit(-1);
    } else {
        LOG(INFO) << "CUDA available! Running on GPU";
        LOG(INFO) << "CUDA is available, running on GPU";
        device = at::kCUDA;
    }
    #endif
    #ifdef USE_IPEX
    torch::jit::setTensorExprFuserEnabled(false);
    #endif
    torch::jit::script::Module model = torch::jit::load(am_model, device);
    model_ = std::make_shared<TorchModule>(std::move(model));
    try {
        torch::jit::script::Module model = torch::jit::load(am_model, device);
        model_ = std::make_shared<TorchModule>(std::move(model));
        LOG(INFO) << "Successfully load model from " << am_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am model: " << am_model << e.what();
        exit(-1);
    }
}
void ParaformerTorch::InitLm(const std::string &lm_file, 
@@ -280,6 +287,7 @@
    paraformer_length.emplace_back(num_frames);
    torch::NoGradGuard no_grad;
    model_->eval();
    torch::Tensor feats =
        torch::from_blob(wav_feats.data(),
                {1, num_frames, feat_dim}, torch::kFloat).contiguous();
@@ -305,15 +313,49 @@
        am_scores = outputs[0].toTensor();
        valid_token_lens = outputs[1].toTensor();
        #endif
        if (lm_ == nullptr) {
            result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
        } else {
            result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
            if (input_finished) {
                result = FinalizeDecode(wfst_decoder);
        // timestamp
        if(outputs.size() == 4){
            torch::Tensor us_alphas_tensor;
            torch::Tensor us_peaks_tensor;
            #ifdef USE_GPU
            us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
            us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
            #else
            us_alphas_tensor = outputs[2].toTensor();
            us_peaks_tensor = outputs[3].toTensor();
            #endif
            int us_alphas_shape_1 = us_alphas_tensor.size(1);
            float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
            std::vector<float> us_alphas(us_alphas_shape_1);
            for (int i = 0; i < us_alphas_shape_1; i++) {
                us_alphas[i] = us_alphas_data[i];
            }
        }
            int us_peaks_shape_1 = us_peaks_tensor.size(1);
            float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
            std::vector<float> us_peaks(us_peaks_shape_1);
            for (int i = 0; i < us_peaks_shape_1; i++) {
                us_peaks[i] = us_peaks_data[i];
            }
         if (lm_ == nullptr) {
                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
         } else {
             result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
                if (input_finished) {
                    result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
                }
         }
        }else{
            if (lm_ == nullptr) {
                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
            } else {
                result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
                if (input_finished) {
                    result = FinalizeDecode(wfst_decoder);
                }
            }
        }
    }
    catch (std::exception const &e)
    {
@@ -324,7 +366,8 @@
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
    std::vector<std::vector<float>> result;
    // TODO
    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
    return result;
}