zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/transducer/model.py
@@ -165,9 +165,13 @@
        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
                                                                                        chunk_outs=None)
        if (
            hasattr(self.encoder, "overlap_chunk_cls")
            and self.encoder.overlap_chunk_cls is not None
        ):
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )
        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
@@ -180,9 +184,7 @@
        decoder_out = self.decoder(decoder_in, u_len)
        
        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        joint_out = self.joint_network(encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
        
        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
@@ -227,7 +229,10 @@
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
@@ -285,12 +290,12 @@
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss
            
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                    "warp-rnnt was not installed." "Please consult the installation documentation."
                )
                exit(1)
        
@@ -346,9 +351,7 @@
            loss_ctc: CTC loss value.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = self.ctc_lin(torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate))
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
        
        target_mask = target != 0
@@ -400,13 +403,12 @@
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(0)
        
        return loss_lm
    
    def init_beam_search(self,
    def init_beam_search(
        self,
                         **kwargs,
                         ):
    
@@ -415,9 +417,7 @@
        
        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
@@ -440,7 +440,8 @@
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
        
    def inference(self,
    def inference(
        self,
                  data_in: list,
                  data_lengths: list=None,
                  key: list=None,
@@ -453,7 +454,9 @@
        
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
@@ -462,13 +465,19 @@
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        )
        
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
@@ -500,14 +509,23 @@
                    token_int = hyp.yseq#[1:last_pos].tolist()
                    
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                token_int = list(
                    filter(
                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                    )
                )
                
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
                result_i = {
                    "key": key[i],
                    "token": token,
                    "text": text,
                    "text_postprocessed": text_postprocessed,
                }
                results.append(result_i)
                
                if ibest_writer is not None:
@@ -516,4 +534,3 @@
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
        
        return results, meta_data