| | |
| | | """Construct a RNNDecoder object.""" |
| | | super().__init__() |
| | | |
| | | |
| | | if rnn_type not in ("lstm", "gru"): |
| | | raise ValueError(f"Not supported: rnn_type={rnn_type}") |
| | | |
| | |
| | | |
| | | rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU |
| | | |
| | | self.rnn = torch.nn.ModuleList( |
| | | [rnn_class(embed_size, hidden_size, 1, batch_first=True)] |
| | | ) |
| | | self.rnn = torch.nn.ModuleList([rnn_class(embed_size, hidden_size, 1, batch_first=True)]) |
| | | |
| | | for _ in range(1, num_layers): |
| | | self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)] |
| | |
| | | time_mask_width_range=3, |
| | | num_time_mask=4, |
| | | apply_freq_mask=False, |
| | | apply_time_warp=False |
| | | apply_time_warp=False, |
| | | ) |
| | | |
| | | |
| | | def forward( |
| | | self, |
| | | labels: torch.Tensor, |
| | |
| | | |
| | | for layer in range(self.dlayers): |
| | | if self.dtype == "lstm": |
| | | x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[ |
| | | layer |
| | | ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])) |
| | | else: |
| | | x, h_next[layer : layer + 1] = self.rnn[layer]( |
| | | x, hx=h_prev[layer : layer + 1] |
| | | x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[layer]( |
| | | x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]) |
| | | ) |
| | | else: |
| | | x, h_next[layer : layer + 1] = self.rnn[layer](x, hx=h_prev[layer : layer + 1]) |
| | | |
| | | x = self.dropout_rnn[layer](x) |
| | | |
| | |
| | | """ |
| | | self.device = device |
| | | |
| | | def init_state( |
| | | self, batch_size: int |
| | | ) -> Tuple[torch.Tensor, Optional[torch.tensor]]: |
| | | def init_state(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.tensor]]: |
| | | """Initialize decoder states. |
| | | |
| | | Args: |
| | |
| | | """ |
| | | return ( |
| | | torch.cat([s[0] for s in new_states], dim=1), |
| | | torch.cat([s[1] for s in new_states], dim=1) |
| | | if self.dtype == "lstm" |
| | | else None, |
| | | torch.cat([s[1] for s in new_states], dim=1) if self.dtype == "lstm" else None, |
| | | ) |