shixian.shi
2023-06-27 8d519b9dc66e0df35c15110ef23a26d07bc7f7c3
update
4个文件已修改
23 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/contextual_decoder.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -285,6 +285,7 @@
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            clas_scale: float = 1.0,
            decoding_ind: int = 0,
            **kwargs,
    ):
@@ -382,6 +383,7 @@
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
        self.clas_scale = clas_scale
        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -446,16 +448,20 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and \
            not isinstance(self.asr_model, NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
                                                                     enc_len,
                                                                     pre_acoustic_embeds,
                                                                     pre_token_length,
                                                                     hw_list=self.hotword_list,
                                                                     clas_scale=self.clas_scale)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
funasr/bin/asr_inference_launch.py
@@ -260,6 +260,7 @@
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None
    clas_scale = param_dict.get('clas_scale', 1.0)
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
@@ -292,6 +293,7 @@
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
        clas_scale=clas_scale,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
funasr/models/decoder/contextual_decoder.py
@@ -246,6 +246,7 @@
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        clas_scale: float = 1.0,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -285,7 +286,7 @@
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)
funasr/models/e2e_asr_contextual_paraformer.py
@@ -343,7 +343,7 @@
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
@@ -365,7 +365,7 @@
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)