From d1a3a7ad90ce9deb7fb4940970bca0abb9409181 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期四, 09 二月 2023 20:44:30 +0800
Subject: [PATCH] Merge pull request #89 from alibaba-damo-academy/dev_lzr
---
funasr/bin/asr_inference_paraformer_vad_punc.py | 17 +++++++++++++----
1 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1d09c79..7d18e02 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -14,6 +14,7 @@
from typing import Any
from typing import List
import math
+import copy
import numpy as np
import torch
from typeguard import check_argument_types
@@ -38,8 +39,9 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
+from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
+from funasr.models.e2e_asr_paraformer import BiCifParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -234,6 +236,10 @@
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if isinstance(self.asr_model, BiCifParaformer):
+ _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ pre_token_length) # test no bias cif2
+
results = []
b, n, d = decoder_out.size()
for i in range(b):
@@ -276,9 +282,12 @@
else:
text = None
- time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
-
- results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+ if isinstance(self.asr_model, BiCifParaformer):
+ timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+ else:
+ time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
--
Gitblit v1.9.1