| | |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | decoding_ind: int = 0, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | |
| | | self.nbest = nbest |
| | | self.frontend = frontend |
| | | self.encoder_downsampling_factor = 1 |
| | | self.decoding_ind = decoding_ind |
| | | if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": |
| | | self.encoder_downsampling_factor = 4 |
| | | |
| | |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | enc, enc_len = self.asr_model.encode(**batch) |
| | | enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind) |
| | | if isinstance(enc, tuple): |
| | | enc = enc[0] |
| | | # assert len(enc) == 1, len(enc) |
| | |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | decoding_ind: int = None, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Decoder + Calc loss |
| | | Args: |
| | |
| | | speech = speech[:, :speech_lengths.max()] |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if hasattr(self.encoder, "overlap_chunk_cls"): |
| | | ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | feats, feats_lengths, ctc=self.ctc |
| | | ) |
| | | if hasattr(self.encoder, "overlap_chunk_cls"): |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | feats, feats_lengths, ctc=self.ctc, ind=ind |
| | | ) |
| | | encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, |
| | | encoder_out_lens, |
| | | chunk_outs=None) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | feats, feats_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | if hasattr(self.encoder, "overlap_chunk_cls"): |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind) |
| | | encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, |
| | | encoder_out_lens, |
| | | chunk_outs=None) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | return var_dict_torch_update |