游雁
2023-05-10 a97daeb247563b14df49ddeed40f991c9916858e
paraformer long batch infer
3个文件已修改
1个文件已添加
841 ■■■■■ 已修改文件
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py 160 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 658 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/vad_utils.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
@@ -2,14 +2,15 @@
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    # Demo: run the Paraformer-large ASR pipeline (with VAD and punctuation
    # models attached) on a remote example recording and print the result.
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
    # None is forwarded to the pipeline as-is; the recognition result is only
    # printed below.
    output_dir = None
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
        output_dir=output_dir,
        # presumably the number of VAD segments decoded together per forward
        # pass -- TODO confirm against the pipeline implementation
        batch_size=8,
    )
    rec_result = inference_pipeline(audio_in=audio_in)
    print(rec_result)
funasr/bin/asr_inference_paraformer.py
@@ -358,160 +358,6 @@
            hotword_list = None
        return hotword_list
class Speech2TextExport:
    """Greedy Paraformer decoder that runs the model through ``Paraformer_export``.

    ``__init__`` loads the ASR model with ``ASRTask.build_model_from_file``,
    optionally builds a ``WavFrontend`` and a tokenizer, and wraps the loaded
    model in ``Paraformer_export(asr_model, onnx=False)``.  ``__call__`` picks
    the arg-max token at every decoder position (no beam search, no LM).
    """
    def __init__(
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        """Load the ASR model and build the helpers used by ``__call__``.

        Args:
                asr_train_config: Training config passed to ``ASRTask.build_model_from_file``.
                asr_model_file: Model checkpoint passed to ``ASRTask.build_model_from_file``.
                cmvn_file: CMVN statistics file, passed to the model loader and to ``WavFrontend``.
                token_type: Tokenizer type; falls back to ``asr_train_args.token_type`` when None.
                bpemodel: BPE model path; falls back to ``asr_train_args.bpemodel`` when None.
                device: Device string, passed to the model loader and stored for ``__call__``.
                dtype: Name of a ``torch`` dtype attribute the model is cast to.
                nbest: Stored on the instance (not read by ``__call__``).

        The remaining parameters (``lm_train_config``, ``lm_file``, ``maxlenratio``,
        ``minlenratio``, ``beam_size``, ``ctc_weight``, ``lm_weight``, ``ngram_weight``,
        ``penalty``, ``frontend_conf``, ``hotword_list_or_file``, ``**kwargs``) are
        accepted but not referenced in this body.
        """
        # 1. Build ASR model
        # NOTE(review): cmvn_file and device are passed positionally here --
        # confirm they match the loader's parameter order.
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        # A feature frontend is only created when the training args declare one.
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        token_list = asr_model.token_list
        logging.info(f"Decoding device={device}, dtype={dtype}")
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        # No tokenizer is built when the type is unknown, or when the type is
        # "bpe" but no BPE model is available.
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        # self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        # The export wrapper, not the raw model, is what __call__ invokes.
        model = Paraformer_export(asr_model, onnx=False)
        self.asr_model = model
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Decode a batch greedily.

        Args:
                speech: Input speech data; a numpy array is converted to a tensor.
                speech_lengths: Lengths matching ``speech``; forwarded to the
                        frontend, or used directly as feature lengths when no
                        frontend is configured.
        Returns:
                A list with one tuple per batch item:
                ``(text, token, token_int, hyp, enc_len_batch_total, lfr_factor)``.
                ``text`` is None when no tokenizer was built.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Clears the wrapped model's own frontend attribute on every call;
            # features were already extracted above.
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        # Sum of the input feature lengths over the batch (taken before the
        # model runs, so these are not encoder output lengths).
        enc_len_batch_total = feats_len.sum()
        # Derived from the last feature dimension in units of 80.
        # presumably 80 is the per-frame fbank size and the result the LFR
        # stacking factor -- TODO confirm
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        decoder_outs = self.asr_model(**batch)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            # Keep only the first ys_pad_lens[i] decoder positions of item i.
            am_scores = decoder_out[i, :ys_pad_lens[i], :]
            yseq = am_scores.argmax(dim=-1)
            score = am_scores.max(dim=-1)[0]
            score = torch.sum(score, dim=-1)
            # NOTE(review): no sos/eos ids are added here (the tensor is only
            # rebuilt from its own values), yet the slice [1:last_pos] below
            # drops the first and last element -- confirm the wrapped model's
            # output already carries boundary positions.
            yseq = torch.tensor(
                yseq.tolist(), device=yseq.device
            )
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)
                # drop the first and last position of the sequence
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove ids 0 and 2 (0 is assumed to be the blank symbol;
                # presumably 2 is eos -- TODO confirm against the vocabulary)
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None
                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
        return results
def inference(
        maxlenratio: float,
@@ -665,10 +511,8 @@
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    if export_mode:
        speech2text = Speech2TextExport(**speech2text_kwargs)
    else:
        speech2text = Speech2Text(**speech2text_kwargs)
    speech2text = Speech2Text(**speech2text_kwargs)
    if timestamp_model_file is not None:
        speechtext2timestamp = SpeechText2Timestamp(
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -47,327 +47,323 @@
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
# ANSI SGR escape sequences for coloured terminal output.
header_colors = '\033[95m'  # SGR 95: bright magenta foreground
end_colors = '\033[0m'  # SGR 0: reset all attributes
class Speech2Text:
    """Paraformer speech-to-text decoder with optional beam search and hotwords.

    ``__init__`` loads the ASR model, optionally a language model, builds the
    scorers / ``BeamSearch`` object, the tokenizer and the hotword list.
    ``__call__`` decodes a batch and returns one tuple per non-empty hypothesis.

    Examples:
            >>> import soundfile
            >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
            >>> audio, rate = soundfile.read("speech.wav")
            >>> speech2text(audio)
            [(text, token, token_int, enc_len_batch_total, lfr_factor), ...]
    """

    def __init__(
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        """Load models and build everything ``__call__`` needs.

        Args:
                asr_train_config: Training config for ``ASRTask.build_model_from_file``.
                asr_model_file: Checkpoint for ``ASRTask.build_model_from_file``.
                cmvn_file: CMVN file, passed to the model loader and ``WavFrontend``.
                lm_train_config: When not None, a language model is loaded and
                        registered as the ``"lm"`` scorer.
                lm_file: Language-model checkpoint; also decides (together with
                        ``lm_weight``) whether beam search is kept.
                token_type: Tokenizer type; defaults to the training args' value.
                bpemodel: BPE model path; defaults to the training args' value.
                device: Device string for models and scorers.
                maxlenratio: Forwarded to beam search on every call.
                minlenratio: Forwarded to beam search on every call.
                dtype: Name of the ``torch`` dtype models are cast to.
                beam_size: Beam width for ``BeamSearch``.
                ctc_weight: Weight of the CTC scorer (decoder gets ``1 - ctc_weight``).
                lm_weight: Weight of the language-model scorer.
                ngram_weight: Weight of the (currently unsupported) n-gram scorer.
                penalty: Weight of the length-bonus scorer.
                nbest: Number of beam-search hypotheses kept per utterance.
                frontend_conf: Accepted but not referenced in this body.
                hotword_list_or_file: Hotword source, see ``generate_hotwords_list``.
                **kwargs: Accepted but not referenced in this body.
        """
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
        )
        frontend = None
        if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

        asr_model.to(dtype=getattr(torch, dtype)).eval()

        if asr_model.ctc is not None:
            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = asr_model.token_list
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()

        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        # No tokenizer when the type is unknown, or "bpe" without a BPE model.
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer

        # 6. [Optional] Build hotword list from str, local file or url
        # (needs self.converter and self.asr_model, hence placed after them)
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)

        # Beam search is dropped (greedy decoding is used instead) when neither
        # CTC nor a language model would contribute to the score.
        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
            beam_search = None
        self.beam_search = beam_search
        logging.info(f"Beam_search: {self.beam_search}")
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        # Used to scale the summed encoder lengths reported by __call__.
        self.encoder_downsampling_factor = 1
        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4

    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
    ):
        """Decode a batch of inputs.

        Args:
                speech: Input speech data; a numpy array is converted to a tensor.
                speech_lengths: Lengths matching ``speech``.
                begin_time: Passed as ``vad_offset`` to the timestamp predictor
                        (only used for ``BiCifParaformer`` models).
                end_time: Accepted but not referenced in this body.
        Returns:
                A list of tuples, one per non-empty hypothesis:
                ``(text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)``
                for ``BiCifParaformer`` models, otherwise
                ``(text, token, token_int, enc_len_batch_total, lfr_factor)``.
                Hypotheses that are empty after filtering are skipped, so the
                list can be shorter than the batch.  An empty list is returned
                when every predicted token length is below 1.
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is not None:
            feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Features are computed here, so the model's own frontend is cleared.
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        # Derived from the last feature dimension in units of 80.
        # presumably 80 is the per-frame fbank size -- TODO confirm
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor

        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        # Nothing to decode: no item in the batch is predicted to hold a token.
        if torch.max(pre_token_length) < 1:
            return []

        # Only ContextualParaformer accepts the hotword list.
        if not isinstance(self.asr_model, ContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = enc[i, :enc_len[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
                )

                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                # Greedy path: arg-max token per position, score = sum of maxima.
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # wrap with sos/eos so the [1:-1] slice below strips only them
                yseq = torch.tensor(
                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove ids 0 and 2 (0 is assumed to be the blank symbol)
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                if len(token_int) == 0:
                    continue

                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)

                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None

                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                            us_peaks[i],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    results.append((text, token, token_int, enc_len_batch_total, lfr_factor))

        # assert check_return_type(results)
        return results

    def _encode_hotwords(self, hotwords):
        """Map hotword strings to id lists and append the trailing sos entry.

        Each hotword is split into single characters before id lookup.
        Returns ``(hotword_list, hotword_str_list)``; the latter is only used
        for logging.
        """
        hotword_str_list = []
        hotword_list = []
        for hw in hotwords:
            hotword_str_list.append(hw)
            hotword_list.append(self.converter.tokens2ids(list(hw)))
        hotword_list.append([self.asr_model.sos])
        hotword_str_list.append('<s>')
        return hotword_list, hotword_str_list

    def _load_hotwords_file(self, path):
        """Read one hotword per line from ``path`` and encode them.

        Blank lines are skipped so they cannot produce a zero-length hotword.
        """
        with codecs.open(path, 'r') as fin:
            stripped = [line.strip() for line in fin]
        return self._encode_hotwords([hw for hw in stripped if hw])

    def generate_hotwords_list(self, hotword_list_or_file):
        """Build the hotword id list from a string, a local file or a URL.

        Args:
                hotword_list_or_file: One of
                        * ``None`` -- hotwords disabled;
                        * path of an existing ``.txt`` file, one hotword per line;
                        * a string starting with ``http`` -- downloaded and parsed
                          like a local file;
                        * any other string not ending in ``.txt`` -- whitespace
                          separated hotwords.
        Returns:
                A list of token-id lists with ``[sos]`` appended as the last
                entry, or ``None`` when hotwords are disabled or the given
                ``.txt`` path does not exist.
        """
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords from local txt...")
            hotword_list, hotword_str_list = self._load_hotwords_file(hotword_list_or_file)
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(hotword_list_or_file, hotword_str_list))
        # for url, download and generate txt
        elif hotword_list_or_file.startswith('http'):
            logging.info("Attempting to parse hotwords from url...")
            # The context manager removes the download directory once parsed;
            # previously the directory was recreated by hand and never deleted.
            with tempfile.TemporaryDirectory() as work_dir:
                # Fall back to a fixed name for URLs ending in "/".
                file_name = os.path.basename(hotword_list_or_file) or "hotwords.txt"
                text_file_path = os.path.join(work_dir, file_name)
                # NOTE(review): no timeout / HTTP status check, as before --
                # an error page would be parsed as hotwords.
                local_file = requests.get(hotword_list_or_file)
                with open(text_file_path, "wb") as fout:
                    fout.write(local_file.content)
                hotword_list, hotword_str_list = self._load_hotwords_file(text_file_path)
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(text_file_path, hotword_str_list))
        # for text str input
        elif not hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords as str...")
            hotword_list, hotword_str_list = self._encode_hotwords(hotword_list_or_file.strip().split())
            logging.info("Hotword list: {}.".format(hotword_str_list))
        else:
            # A ".txt" path that does not exist: keep returning None, but say so.
            logging.warning("Hotword file {} does not exist, hotwords are disabled."
                            .format(hotword_list_or_file))
            hotword_list = None
        return hotword_list
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.bin.asr_inference_paraformer import Speech2Text
# class Speech2Text:
#     """Speech2Text class
#
#     Examples:
#             >>> import soundfile
#             >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
#             >>> audio, rate = soundfile.read("speech.wav")
#             >>> speech2text(audio)
#             [(text, token, token_int, hypothesis object), ...]
#
#     """
#
#     def __init__(
#             self,
#             asr_train_config: Union[Path, str] = None,
#             asr_model_file: Union[Path, str] = None,
#             cmvn_file: Union[Path, str] = None,
#             lm_train_config: Union[Path, str] = None,
#             lm_file: Union[Path, str] = None,
#             token_type: str = None,
#             bpemodel: str = None,
#             device: str = "cpu",
#             maxlenratio: float = 0.0,
#             minlenratio: float = 0.0,
#             dtype: str = "float32",
#             beam_size: int = 20,
#             ctc_weight: float = 0.5,
#             lm_weight: float = 1.0,
#             ngram_weight: float = 0.9,
#             penalty: float = 0.0,
#             nbest: int = 1,
#             frontend_conf: dict = None,
#             hotword_list_or_file: str = None,
#             **kwargs,
#     ):
#         assert check_argument_types()
#
#         # 1. Build ASR model
#         scorers = {}
#         asr_model, asr_train_args = ASRTask.build_model_from_file(
#             asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
#         )
#         frontend = None
#         if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
#             frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
#
#         # logging.info("asr_model: {}".format(asr_model))
#         # logging.info("asr_train_args: {}".format(asr_train_args))
#         asr_model.to(dtype=getattr(torch, dtype)).eval()
#
#         if asr_model.ctc != None:
#             ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
#             scorers.update(
#                 ctc=ctc
#             )
#         token_list = asr_model.token_list
#         scorers.update(
#             length_bonus=LengthBonus(len(token_list)),
#         )
#
#         # 2. Build Language model
#         if lm_train_config is not None:
#             lm, lm_train_args = LMTask.build_model_from_file(
#                 lm_train_config, lm_file, device
#             )
#             scorers["lm"] = lm.lm
#
#         # 3. Build ngram model
#         # ngram is not supported now
#         ngram = None
#         scorers["ngram"] = ngram
#
#         # 4. Build BeamSearch object
#         # transducer is not supported now
#         beam_search_transducer = None
#
#         weights = dict(
#             decoder=1.0 - ctc_weight,
#             ctc=ctc_weight,
#             lm=lm_weight,
#             ngram=ngram_weight,
#             length_bonus=penalty,
#         )
#         beam_search = BeamSearch(
#             beam_size=beam_size,
#             weights=weights,
#             scorers=scorers,
#             sos=asr_model.sos,
#             eos=asr_model.eos,
#             vocab_size=len(token_list),
#             token_list=token_list,
#             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
#         )
#
#         beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
#         for scorer in scorers.values():
#             if isinstance(scorer, torch.nn.Module):
#                 scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
#
#         logging.info(f"Decoding device={device}, dtype={dtype}")
#
#         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
#         if token_type is None:
#             token_type = asr_train_args.token_type
#         if bpemodel is None:
#             bpemodel = asr_train_args.bpemodel
#
#         if token_type is None:
#             tokenizer = None
#         elif token_type == "bpe":
#             if bpemodel is not None:
#                 tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
#             else:
#                 tokenizer = None
#         else:
#             tokenizer = build_tokenizer(token_type=token_type)
#         converter = TokenIDConverter(token_list=token_list)
#         logging.info(f"Text tokenizer: {tokenizer}")
#
#         self.asr_model = asr_model
#         self.asr_train_args = asr_train_args
#         self.converter = converter
#         self.tokenizer = tokenizer
#
#         # 6. [Optional] Build hotword list from str, local file or url
#         self.hotword_list = None
#         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
#
#         is_use_lm = lm_weight != 0.0 and lm_file is not None
#         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
#             beam_search = None
#         self.beam_search = beam_search
#         logging.info(f"Beam_search: {self.beam_search}")
#         self.beam_search_transducer = beam_search_transducer
#         self.maxlenratio = maxlenratio
#         self.minlenratio = minlenratio
#         self.device = device
#         self.dtype = dtype
#         self.nbest = nbest
#         self.frontend = frontend
#         self.encoder_downsampling_factor = 1
#         if asr_train_args.encoder_conf["input_layer"] == "conv2d":
#             self.encoder_downsampling_factor = 4
#
#     @torch.no_grad()
#     def __call__(
#             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
#             begin_time: int = 0, end_time: int = None,
#     ):
#         """Inference
#
#         Args:
#                 speech: Input speech data
#         Returns:
#                 text, token, token_int, hyp
#
#         """
#         assert check_argument_types()
#
#         # Input as audio signal
#         if isinstance(speech, np.ndarray):
#             speech = torch.tensor(speech)
#
#         if self.frontend is not None:
#             feats, feats_len = self.frontend.forward(speech, speech_lengths)
#             # fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
#             # feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
#             feats = to_device(feats, device=self.device)
#             feats_len = feats_len.int()
#             self.asr_model.frontend = None
#         else:
#             feats = speech
#             feats_len = speech_lengths
#         lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
#         batch = {"speech": feats, "speech_lengths": feats_len}
#
#         # a. To device
#         batch = to_device(batch, device=self.device)
#
#         # b. Forward Encoder
#         enc, enc_len = self.asr_model.encode(**batch)
#         if isinstance(enc, tuple):
#             enc = enc[0]
#         # assert len(enc) == 1, len(enc)
#         enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
#
#         predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
#         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
#                                                                         predictor_outs[2], predictor_outs[3]
#         pre_token_length = pre_token_length.round().long()
#         if torch.max(pre_token_length) < 1:
#             return []
#
#         if not isinstance(self.asr_model, ContextualParaformer):
#             if self.hotword_list:
#                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
#             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
#             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
#         else:
#             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
#             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
#
#         if isinstance(self.asr_model, BiCifParaformer):
#             _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
#                                                                                    pre_token_length)  # test no bias cif2
#
#         results = []
#         b, n, d = decoder_out.size()
#         for i in range(b):
#             x = enc[i, :enc_len[i], :]
#             am_scores = decoder_out[i, :pre_token_length[i], :]
#             if self.beam_search is not None:
#                 nbest_hyps = self.beam_search(
#                     x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
#                 )
#
#                 nbest_hyps = nbest_hyps[: self.nbest]
#             else:
#                 yseq = am_scores.argmax(dim=-1)
#                 score = am_scores.max(dim=-1)[0]
#                 score = torch.sum(score, dim=-1)
#                 # pad with mask tokens to ensure compatibility with sos/eos tokens
#                 yseq = torch.tensor(
#                     [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
#                 )
#                 nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
#
#             for hyp in nbest_hyps:
#                 assert isinstance(hyp, (Hypothesis)), type(hyp)
#
#                 # remove sos/eos and get results
#                 last_pos = -1
#                 if isinstance(hyp.yseq, list):
#                     token_int = hyp.yseq[1:last_pos]
#                 else:
#                     token_int = hyp.yseq[1:last_pos].tolist()
#
#                 # remove blank symbol id, which is assumed to be 0
#                 token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
#                 if len(token_int) == 0:
#                     continue
#
#                 # Change integer-ids to tokens
#                 token = self.converter.ids2tokens(token_int)
#
#                 if self.tokenizer is not None:
#                     text = self.tokenizer.tokens2text(token)
#                 else:
#                     text = None
#
#                 if isinstance(self.asr_model, BiCifParaformer):
#                     _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
#                                                             us_peaks[i],
#                                                             copy.copy(token),
#                                                             vad_offset=begin_time)
#                     results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
#                 else:
#                     results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
#
#         # assert check_return_type(results)
#         return results
#
#     def generate_hotwords_list(self, hotword_list_or_file):
#         # for None
#         if hotword_list_or_file is None:
#             hotword_list = None
#         # for local txt inputs
#         elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
#             logging.info("Attempting to parse hotwords from local txt...")
#             hotword_list = []
#             hotword_str_list = []
#             with codecs.open(hotword_list_or_file, 'r') as fin:
#                 for line in fin.readlines():
#                     hw = line.strip()
#                     hotword_str_list.append(hw)
#                     hotword_list.append(self.converter.tokens2ids([i for i in hw]))
#                 hotword_list.append([self.asr_model.sos])
#                 hotword_str_list.append('<s>')
#             logging.info("Initialized hotword list from file: {}, hotword list: {}."
#                          .format(hotword_list_or_file, hotword_str_list))
#         # for url, download and generate txt
#         elif hotword_list_or_file.startswith('http'):
#             logging.info("Attempting to parse hotwords from url...")
#             work_dir = tempfile.TemporaryDirectory().name
#             if not os.path.exists(work_dir):
#                 os.makedirs(work_dir)
#             text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
#             local_file = requests.get(hotword_list_or_file)
#             open(text_file_path, "wb").write(local_file.content)
#             hotword_list_or_file = text_file_path
#             hotword_list = []
#             hotword_str_list = []
#             with codecs.open(hotword_list_or_file, 'r') as fin:
#                 for line in fin.readlines():
#                     hw = line.strip()
#                     hotword_str_list.append(hw)
#                     hotword_list.append(self.converter.tokens2ids([i for i in hw]))
#                 hotword_list.append([self.asr_model.sos])
#                 hotword_str_list.append('<s>')
#             logging.info("Initialized hotword list from file: {}, hotword list: {}."
#                          .format(hotword_list_or_file, hotword_str_list))
#         # for text str input
#         elif not hotword_list_or_file.endswith('.txt'):
#             logging.info("Attempting to parse hotwords as str...")
#             hotword_list = []
#             hotword_str_list = []
#             for hw in hotword_list_or_file.strip().split():
#                 hotword_str_list.append(hw)
#                 hotword_list.append(self.converter.tokens2ids([i for i in hw]))
#             hotword_list.append([self.asr_model.sos])
#             hotword_str_list.append('<s>')
#             logging.info("Hotword list: {}.".format(hotword_str_list))
#         else:
#             hotword_list = None
#         return hotword_list
def inference(
@@ -611,15 +607,17 @@
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            vad_results = speech2vadsegment(**batch)
            fbanks, vadsegments = vad_results[0], vad_results[1]
            _, vadsegments = vad_results[0], vad_results[1]
            speech, speech_lengths = batch["speech"],  batch["speech_lengths"]
            for i, segments in enumerate(vadsegments):
                result_segments = [["", [], [], []]]
                for j, segment_idx in enumerate(segments):
                    bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
                    segment = fbanks[:, bed_idx:end_idx, :].to(device)
                    speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
                    batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
                             "end_time": vadsegments[i][j][1]}
                # for j, segment_idx in enumerate(segments):
                for j, beg_idx in enumerate(range(0, len(segments), batch_size)):
                    end_idx = min(len(segments), beg_idx + batch_size)
                    speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, segments[beg_idx:end_idx])
                    batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                    batch = to_device(batch, device=device)
                    results = speech2text(**batch)
                    if len(results) < 1:
                        continue
@@ -633,8 +631,8 @@
                key = keys[0]
                result = result_segments[0]
                text, token, token_int = result[0], result[1], result[2]
                time_stamp = None if len(result) < 4 else result[3]
                text, token, token_int, hyp = result[0], result[1], result[2], result[3]
                time_stamp = None if len(result) < 5 else result[4]
                if use_timestamp and time_stamp is not None: 
funasr/utils/vad_utils.py
New file
@@ -0,0 +1,18 @@
import torch
from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments, samples_per_ms=16):
    """Cut the first utterance of a batch into segments and pad them together.

    Only ``speech[0]`` and ``speech_lengths[0]`` are read; any further batch
    items are ignored.

    Args:
        speech: Tensor indexed as ``speech[0, begin:end]``.
        speech_lengths: Tensor or sequence whose first element is the number
            of valid positions in ``speech[0]``.
        vad_segments: Iterable of ``(begin, end)`` pairs. Each bound is
            multiplied by ``samples_per_ms`` and truncated to an int to obtain
            an index into ``speech[0]``.
        samples_per_ms: Scale applied to the segment bounds. The default of 16
            presumably corresponds to millisecond timestamps over 16 kHz
            audio -- confirm against the caller.

    Returns:
        A ``(feats_pad, speech_lengths_pad)`` tuple: the zero-padded slices
        stacked batch-first, and an int32 tensor holding the unpadded length
        of every slice. Both are empty when ``vad_segments`` is empty.
    """
    # Resolve the valid length once as a plain int so that no 0-dim tensors
    # leak into the index arithmetic or into the returned lengths.
    total = int(speech_lengths[0])

    speech_list = []
    speech_lengths_list = []
    for segment in vad_segments:
        # Clamp both bounds into [0, total] and keep end >= begin, so the
        # reported length always equals the size of the slice actually taken
        # (never negative, never wrapping around from the end of the tensor).
        bed_idx = min(max(int(segment[0] * samples_per_ms), 0), total)
        end_idx = min(max(int(segment[1] * samples_per_ms), bed_idx), total)
        speech_list.append(speech[0, bed_idx:end_idx])
        speech_lengths_list.append(end_idx - bed_idx)

    if not speech_list:
        # pad_sequence rejects an empty list; return an empty batch instead.
        empty_feats = speech.new_zeros((0, 0) + tuple(speech.shape[2:]))
        return empty_feats, torch.zeros(0, dtype=torch.int32)

    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
    # Build the lengths directly as int32: a float32 round-trip cannot
    # represent integers above 2**24 exactly.
    speech_lengths_pad = torch.tensor(speech_lengths_list, dtype=torch.int32)
    return feats_pad, speech_lengths_pad