wanchen.swc
2023-03-30 cfc4d402093060fe087424b0a6be4e2b2546eae8
funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -58,6 +58,9 @@
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            try:
                if int(device_id) != -1:
                    feats = feats.cuda()
                    feats_len = feats_len.cuda()
                outputs = self.ort_infer(feats, feats_len)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                if len(outputs) == 4: