zhifu gao
2024-03-01 5608bee0accea5e12030f8e1b6f7d62eee4dd892
fix bug (#1412)

4个文件已修改
15 ■■■■ 已修改文件
funasr/models/contextual_paraformer/model.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer_streaming/model.py 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/model.py
@@ -190,13 +190,10 @@
        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds, contextual_info)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        
        # 1. Forward decoder
funasr/models/paraformer/model.py
@@ -154,8 +154,8 @@
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        # self.step_cur = 0
        #
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
funasr/models/paraformer_streaming/model.py
@@ -235,8 +235,7 @@
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
@@ -246,8 +245,6 @@
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        
        # 1. Forward decoder
funasr/models/seaco_paraformer/model.py
@@ -130,7 +130,6 @@
        dha_pad = kwargs.get("dha_pad")
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]