| | |
| | | encoder: AbsEncoder, |
| | | predictor: CifPredictorV3, |
| | | predictor_bias: int = 0, |
| | | token_list=None, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | |
| | | self.predictor = predictor |
| | | self.predictor_bias = predictor_bias |
| | | self.criterion_pre = mae_loss() |
| | | self.token_list = token_list |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | encoder_out_mask, |
| | | token_num) |
| | | return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak |
| | | |
def collect_feats(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Return the feature tensors to use for the stats-collection pass.

    If ``self.extract_feats_in_collect_stats`` is truthy, the features come
    from ``self._extract_feats(speech, speech_lengths)``. Otherwise a warning
    is logged and the raw ``speech`` / ``speech_lengths`` are passed through
    unchanged as dummy stats.

    ``text`` and ``text_lengths`` are accepted to keep the call signature
    uniform but are not used by this method.

    Returns:
        A dict with exactly two keys, ``"feats"`` and ``"feats_lengths"``.
    """
    if not self.extract_feats_in_collect_stats:
        # Dummy-stats path: no feature extraction, inputs are forwarded as-is.
        logging.warning(
            "Generating dummy stats for feats and feats_lengths, "
            "because encoder_conf.extract_feats_in_collect_stats is "
            f"{self.extract_feats_in_collect_stats}"
        )
        return {"feats": speech, "feats_lengths": speech_lengths}

    extracted, extracted_lengths = self._extract_feats(speech, speech_lengths)
    return {"feats": extracted, "feats_lengths": extracted_lengths}