| | |
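| | | // Disable graph-executor optimization and pin the fusion strategy to |
| | | // STATIC with depth 0 (no shape specializations), so TorchScript does not |
| | | // recompile the graph for each new input length at inference time. |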
| | | torch::jit::setGraphExecutorOptimize(false); |
| | | torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}}; |
| | | torch::jit::setFusionStrategy(static0); |
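| | | // On GPU builds, run one dummy forward pass now so the first real request |
| | | // does not pay the one-time CUDA/JIT initialization cost. |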
| | | #ifdef USE_GPU |
| | | WarmUp(); |
| | | #endif |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when loading am model " << am_model << ": " << e.what(); |
| | | exit(-1); |
| | | } |
| | | return results; |
| | | } |
| | | |
| | | void ParaformerTorch::WarmUp() |
| | | { |
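| | | // 1. build a dummy input batch. With low frame rate (LFR) stacking, each |
| | | // model-side frame concatenates lfr_m fbank frames of num_bins mels, hence |
| | | // feature_dim = lfr_m * in_feat_dim. |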
| | | int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins; |
| | | int32_t feature_dim = lfr_m * in_feat_dim; |
| | | int batch_in = 1; |
| | | int max_frames = 10; |
| | | std::vector<int32_t> paraformer_length; |
| | | paraformer_length.push_back(max_frames); |
| | | |
| | | std::vector<float> all_feats(batch_in * max_frames * feature_dim, 0.1f); |
| | | torch::Tensor feats = |
| | | torch::from_blob(all_feats.data(), |
| | | {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous(); |
| | | torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(), |
| | | {batch_in}, torch::kInt32); |
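| | | // from_blob wraps the existing host buffers without copying, so all_feats |
| | | // and paraformer_length must outlive these tensors; the .to(at::kCUDA) |
| | | // calls below then copy the data to the device. |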
| | | |
| | | // 2. move inputs to the device and run one forward pass |
| | | feats = feats.to(at::kCUDA); |
| | | feat_lens = feat_lens.to(at::kCUDA); |
| | | std::vector<torch::jit::IValue> inputs = {feats, feat_lens}; |
| | | |
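| | | // When hotword biasing is enabled, the scripted model takes a third input: |
| | | // the hotword embeddings as a {batch, num_hotwords, dim} tensor. Note that |
| | | // CompileHotwordEmbedding must return at least one row even for the empty |
| | | // string, since hw_emb[0] is indexed below. |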
| | | if (use_hotword) { |
| | | std::string hotwords_wp = ""; |
| | | std::vector<std::vector<float>> hw_emb = CompileHotwordEmbedding(hotwords_wp); |
| | | std::vector<float> embedding; |
| | | embedding.reserve(hw_emb.size() * hw_emb[0].size()); |
| | | for (const auto &item : hw_emb) { |
| | | embedding.insert(embedding.end(), item.begin(), item.end()); |
| | | } |
| | | torch::Tensor tensor_hw_emb = |
| | | torch::from_blob(embedding.data(), |
| | | {batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous(); |
| | | tensor_hw_emb = tensor_hw_emb.to(at::kCUDA); |
| | | inputs.emplace_back(tensor_hw_emb); |
| | | } |
| | | |
| | | try { |
| | | auto outputs = model_->forward(inputs).toTuple()->elements(); |
| | | } |
| | | catch (std::exception const &e) |
| | | { |
| | | LOG(ERROR) << "Error during warmup forward: " << e.what(); |
| | | } |
| | | } |
| | | |
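| | | // Produces one embedding row of width encoder_size per hotword; WarmUp |
| | | // above relies on at least one row coming back even for an empty input. |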
| | | std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) { |
| | | int embedding_dim = encoder_size; |
| | | std::vector<std::vector<float>> hw_emb; |