| | |
| | | whisper_model: str = None, |
| | | do_pad_trim: bool = True, |
| | | n_mels: int = 80, |
| | | permute: bool = False, |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | assert fs == 16000 |
| | |
| | | self.do_pad_trim = do_pad_trim |
| | | if do_pad_trim: |
| | | self.pad_or_trim = whisper.pad_or_trim |
| | | self.permute = permute |
| | | |
| | | # assert whisper_model in whisper.available_models() |
| | | |
| | |
| | | return log_spec, olens |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | batch_size = input.size(0) |
| | | feats = [] |
| | |
| | | feats_pad = pad_sequence(feats, |
| | | batch_first=True, |
| | | padding_value=0.0) |
| | | |
| | | if self.permute: |
| | | feats_pad = feats_pad.permute(0, 2, 1) |
| | | return feats_pad, feats_lens |