| | |
| | | """Sequential implementation of Recurrent Neural Network Language Model.""" |
| | | |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | |
| | | self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id) |
| | | if rnn_type in ["LSTM", "GRU"]: |
| | | rnn_class = getattr(nn, rnn_type) |
| | | self.rnn = rnn_class( |
| | | ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True |
| | | ) |
| | | self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True) |
| | | else: |
| | | try: |
| | | nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] |
| | |
| | | # https://arxiv.org/abs/1611.01462 |
| | | if tie_weights: |
| | | if nhid != ninp: |
| | | raise ValueError( |
| | | "When using the tied flag, nhid must be equal to emsize" |
| | | ) |
| | | raise ValueError("When using the tied flag, nhid must be equal to emsize") |
| | | self.decoder.weight = self.encoder.weight |
| | | |
| | | self.rnn_type = rnn_type |