nichongjia-2007
2023-07-07 1ce704d8c09bd4d4c7e5ab087f951f31fad9fca6
funasr/bin/asr_infer.py
@@ -280,6 +280,7 @@
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            clas_scale: float = 1.0,
            decoding_ind: int = 0,
            **kwargs,
    ):
@@ -376,6 +377,7 @@
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
        self.clas_scale = clas_scale
        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -439,16 +441,20 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and \
            not isinstance(self.asr_model, NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
                                                                     enc_len,
                                                                     pre_acoustic_embeds,
                                                                     pre_token_length,
                                                                     hw_list=self.hotword_list,
                                                                     clas_scale=self.clas_scale)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):