游雁
2024-03-21 bbda5496ffae1d9ab052e8736a8c0b080ea017f5
funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
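Note on the two added lines: some dataloaders/collators deliver the length tensors with a trailing singleton dimension, which the `dim() == 1` assert below rejects; `squeeze()` normalizes them to `(Batch,)`. A minimal sketch of the behavior (values illustrative; the batch-size-1 caveat is my observation, not part of the commit):

    import torch

    lengths = torch.tensor([[12], [7], [9]])  # (3, 1), as a collator might emit
    lengths = lengths.squeeze()               # -> (3,), satisfies dim() == 1
    assert lengths.dim() == 1

    # Caveat: with batch size 1, a (1, 1) tensor squeezes all the way to 0-d
    # and the assert would fire; squeeze(-1) would avoid that edge case.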
@@ -128,7 +130,7 @@
    
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
-        dha_pad = kwargs.get("dha_pad")
+        seaco_label_pad = kwargs.get("seaco_label_pad")
 
         batch_size = speech.shape[0]
         # for data-parallel
@@ -148,7 +150,7 @@
                                         ys_lengths,
                                         hotword_pad,
                                         hotword_lengths,
-                                        dha_pad,
+                                        seaco_label_pad,
                                         )
         if self.train_decoder:
             loss_att, acc_att = self._calc_att_loss(
@@ -164,7 +166,7 @@
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -185,13 +187,12 @@
             ys_lengths: torch.Tensor,
             hotword_pad: torch.Tensor,
             hotword_lengths: torch.Tensor,
-            dha_pad: torch.Tensor,
+            seaco_label_pad: torch.Tensor,
     ):
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad,
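Switching from a fixed 4-way unpack to `[0]` keeps only `pre_acoustic_embeds` and, presumably, tolerates predictor variants whose `forward()` returns a different number of values; the old unpack raises `ValueError` the moment the arity changes. A generic illustration (both predictor functions are hypothetical):

    def predictor_a(x):   # hypothetical stand-in returning four values
        return x, None, None, None

    def predictor_b(x):   # hypothetical stand-in returning three values
        return x, None, None

    emb = predictor_b("feats")[0]          # indexing works for either arity
    # emb, _, _, _ = predictor_b("feats")  # ValueError: not enough values to unpack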
@@ -204,7 +205,7 @@
         dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths)
         merged = self._merge(cif_attended, dec_attended)
         dha_output = self.hotword_output_layer(merged[:, :-1])  # remove the last token in loss calculation
-        loss_att = self.criterion_seaco(dha_output, dha_pad)
+        loss_att = self.criterion_seaco(dha_output, seaco_label_pad)
         return loss_att
 
     def _seaco_decode_with_ASF(self,
@@ -430,7 +431,6 @@
        
         return results, meta_data
 
     def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
         def load_seg_dict(seg_dict_file):
             seg_dict = {}
@@ -532,3 +532,13 @@
             hotword_list = None
         return hotword_list
+
+    def export(
+        self,
+        **kwargs,
+    ):
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
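For context, the added `export()` is the hook FunASR's `AutoModel.export()` dispatches to; a usage sketch along the lines of the documented export flow (model ID and keyword arguments are assumptions, not part of this commit):

    from funasr import AutoModel

    # Hypothetical usage: export a SeACo-Paraformer checkpoint to ONNX.
    model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    res = model.export(type="onnx", quantize=False)  # ends up in SeacoParaformer.export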