                )
                att_list.append(att)
    else:
        raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
    return att_list

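# Usage sketch for build_attention_list: the hyper-parameter values below are
# assumed for illustration; any keys beyond eprojs/dunits come from the
# caller's att_conf dict and are forwarded to the attention constructor.
#
#     att_list = build_attention_list(eprojs=256, dunits=320, num_encs=1)
#     assert isinstance(att_list, torch.nn.ModuleList)
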
@tables.register("decoder_classes", "rnn_decoder")
class RNNDecoder(nn.Module):

        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            (
                torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
            )
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [
                (
                    torch.nn.LSTMCell(hidden_size, hidden_size)
                    if self.dtype == "lstm"
                    else torch.nn.GRUCell(hidden_size, hidden_size)
                )
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        # NOTE: dropout is applied only for the vertical connections
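        # Resulting stack for, e.g., dlayers=2 with dtype="lstm" (values
        # assumed for illustration):
        #   decoder[0]: LSTMCell(hidden_size + eprojs -> hidden_size)  # embedding + attention context
        #   decoder[1]: LSTMCell(hidden_size -> hidden_size)
        # dropout_dec[i] drops decoder[i]'s output before it feeds
        # decoder[i + 1]; the recurrent (horizontal) state is never dropped.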

        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        self.att_list = build_attention_list(eprojs=eprojs, dunits=hidden_size, **att_conf)

    def zero_state(self, hs_pad):
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
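    # Example (sketch): per-layer initial states for a batch of encoder
    # outputs. hs_pad is assumed to be (batch, time, eprojs); each state is a
    # (batch, dunits) zero tensor, where self.dunits presumably aliases the
    # hidden_size set in __init__. new_zeros keeps device and dtype consistent
    # with the encoder output.
    #
    #     z_list = [self.zero_state(hs_pad) for _ in range(self.dlayers)]
    #     c_list = [self.zero_state(hs_pad) for _ in range(self.dlayers)]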

        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i])
        return z_list, c_list
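    # rnn_forward usage sketch: in the GRU branch above there is no cell
    # state, so only z_list is updated and c_list passes through unchanged.
    # ey is assumed to be the embedding/context concatenation of shape
    # (batch, dunits + eprojs):
    #
    #     z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)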

    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
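        # Argument roles (assumed from the ESPnet-style decoder interface):
        # hs_pad is the padded encoder output (batch, time, eprojs), hlens the
        # corresponding lengths, ys_in_pad / ys_in_lens the padded teacher-
        # forcing token ids and their lengths, and strm_idx selects the
        # attention stream (relevant for multi-speaker setups).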

                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"])
        if self.context_residual:
            logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
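        # logp is a (vocab_size,) vector of log-probabilities for the single
        # hypothesis being scored, hence the squeeze over the batch dim of 1;
        # a beam-search driver would typically take logp.topk(beam) to expand
        # the hypothesis (assumed usage).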