huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
funasr/models/e2e_uni_asr.py
@@ -198,16 +198,15 @@
        # For data-parallel: truncate text and speech to the longest length present in this batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max(), :]
        speech = speech[:, :speech_lengths.max()]
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        speech_raw = speech.clone().to(speech.device)
        # 1. Encoder
        if self.enable_maas_finetune:
            with torch.no_grad():
                encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
            speech_raw, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
@@ -486,7 +485,7 @@
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        speech_raw = feats.clone().to(feats.device)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
@@ -523,7 +522,7 @@
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
        return speech_raw, encoder_out, encoder_out_lens
    def encode2(
        self,