| | |
| | | else: |
| | | self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) |
| | | |
| | | def forward( |
| | | self, feats: torch.Tensor, feats_mask: torch.Tensor |
| | | ) -> Union[ |
| | | def forward(self, feats: torch.Tensor, feats_mask: torch.Tensor) -> Union[ |
| | | Tuple[torch.Tensor, torch.Tensor], |
| | | Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], |
| | | ]: |
| | |
| | | |
| | | b, c, t, f = vgg_output.size() |
| | | |
| | | vgg_output = self.output( |
| | | vgg_output.transpose(1, 2).contiguous().view(b, t, c * f) |
| | | ) |
| | | vgg_output = self.output(vgg_output.transpose(1, 2).contiguous().view(b, t, c * f)) |
| | | |
| | | if feats_mask is not None: |
| | | vgg_mask = self.create_new_mask(feats_mask) |