| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | dropout_rate: float = 0.0, |
| | | whisper_model: str = "small", |
| | | download_dir: str = None, |
| | | use_specaug: bool = False, |
| | | use_padmask: bool = False, |
| | | specaug_conf: Union[dict, None] = None, |
| | | self, |
| | | dropout_rate: float = 0.0, |
| | | whisper_model: str = "small", |
| | | download_dir: str = None, |
| | | use_specaug: bool = False, |
| | | use_padmask: bool = False, |
| | | specaug_conf: Union[dict, None] = None, |
| | | ): |
| | | super().__init__() |
| | | |
| | |
| | | self.dropout = torch.nn.Dropout(dropout_rate) |
| | | |
| | | assert whisper_model in whisper.available_models() |
| | | _model = whisper.load_model( |
| | | whisper_model, download_root=download_dir, device="cpu" |
| | | ) |
| | | _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu") |
| | | self.encoders = copy.deepcopy(_model.encoder) |
| | | self.encoders.train() |
| | | |
| | |
| | | self.use_padmask = use_padmask |
| | | |
| | | def whisper_encode( |
| | | self, |
| | | input: torch.Tensor, |
| | | ilens: torch.Tensor = None, |
| | | self, |
| | | input: torch.Tensor, |
| | | ilens: torch.Tensor = None, |
| | | ) -> torch.Tensor: |
| | | x = F.gelu(self.encoders.conv1(input)) |
| | | x = F.gelu(self.encoders.conv2(x)) |
| | |
| | | |
| | | if ilens is not None: |
| | | olens = ( |
| | | 1 |
| | | + ( |
| | | ilens |
| | | - self.encoders.conv2.kernel_size[0] |
| | | + 2 * self.encoders.conv2.padding[0] |
| | | ) |
| | | // self.encoders.conv2.stride[0] |
| | | 1 |
| | | + (ilens - self.encoders.conv2.kernel_size[0] + 2 * self.encoders.conv2.padding[0]) |
| | | // self.encoders.conv2.stride[0] |
| | | ) |
| | | olens = torch.clamp(olens, max=max_pos) |
| | | else: |
| | |
| | | return self.encoders.conv2.weight.shape[0] |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | feats, feats_lens = xs_pad, ilens |
| | | |