| | |
| | | """Speech Pretrained Representation frontend structure for ASR.""" |
| | | |
| | | def __init__( |
| | | self, |
| | | fs: Union[int, str] = 16000, |
| | | frontend_conf: Optional[dict] = None, |
| | | download_dir: str = None, |
| | | multilayer_feature: bool = False, |
| | | self, |
| | | fs: Union[int, str] = 16000, |
| | | frontend_conf: Optional[dict] = None, |
| | | download_dir: str = None, |
| | | multilayer_feature: bool = False, |
| | | ): |
| | | super().__init__() |
| | | if isinstance(fs, str): |
| | |
| | | ).to("cpu") |
| | | |
| | | if getattr( |
| | | s3prl_upstream, "model", None |
| | | s3prl_upstream, "model", None |
| | | ) is not None and s3prl_upstream.model.__class__.__name__ in [ |
| | | "Wav2Vec2Model", |
| | | "HubertModel", |
| | |
| | | Output - sequence of tiled representations |
| | | shape: (batch_size, seq_len * factor, feature_dim) |
| | | """ |
| | | assert ( |
| | | len(feature.shape) == 3 |
| | | ), "Input argument `feature` has invalid shape: {}".format(feature.shape) |
| | | assert len(feature.shape) == 3, "Input argument `feature` has invalid shape: {}".format( |
| | | feature.shape |
| | | ) |
| | | tiled_feature = feature.repeat(1, 1, self.args.tile_factor) |
| | | tiled_feature = tiled_feature.reshape( |
| | | feature.size(0), feature.size(1) * self.args.tile_factor, feature.size(2) |
| | |
| | | return self.output_dim |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)] |
| | | self.upstream.eval() |