| | |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | decoding_ind: int = 0, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | |
| | | self.nbest = nbest |
| | | self.frontend = frontend |
| | | self.encoder_downsampling_factor = 1 |
| | | self.decoding_ind = decoding_ind |
| | | if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": |
| | | self.encoder_downsampling_factor = 4 |
| | | |
| | |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | enc, enc_len = self.asr_model.encode(**batch) |
| | | enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind) |
| | | if isinstance(enc, tuple): |
| | | enc = enc[0] |
| | | # assert len(enc) == 1, len(enc) |