querryton
2024-04-20 01df8f330ccc754223d5e2d688dc0a55d27f2dcc
funasr/models/seaco_paraformer/model.py
@@ -97,7 +97,8 @@
            smoothing=seaco_lsm_weight,
            normalize_length=seaco_length_normalized_loss,
        )
        self.train_decoder = kwargs.get("train_decoder", False)
        self.train_decoder = kwargs.get("train_decoder", True)
        self.seaco_weight = kwargs.get("seaco_weight", 0.01)
        self.NO_BIAS = kwargs.get("NO_BIAS", 8377)
        self.predictor_name = kwargs.get("predictor")
        
@@ -117,7 +118,10 @@
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        # Check that batch_size is unified
        assert (
                speech.shape[0]
@@ -129,6 +133,8 @@
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        seaco_label_pad = kwargs.get("seaco_label_pad")
        if len(hotword_lengths.size()) > 1:
            hotword_lengths = hotword_lengths[:, 0]
        
        batch_size = speech.shape[0]
        # for data-parallel
@@ -151,20 +157,21 @@
                                        seaco_label_pad,
                                        )
        if self.train_decoder:
            loss_att, acc_att = self._calc_att_loss(
            loss_att, acc_att, _, _, _ = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss = loss_seaco + loss_att
            loss = loss_seaco + loss_att * self.seaco_weight
            stats["loss_att"] = torch.clone(loss_att.detach())
            stats["acc_att"] = acc_att
        else:
            loss = loss_seaco
        stats["loss_seaco"] = torch.clone(loss_seaco.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
@@ -190,8 +197,7 @@
        # predictor forward
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                                  ignore_id=self.ignore_id)
        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
        # decoder forward
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
        selected = self._hotword_representation(hotword_pad, 
@@ -344,7 +350,7 @@
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
            return ([],)
        decoder_out = self._seaco_decode_with_ASF(encoder_out, 
                                                  encoder_out_lens,