wuhongsheng
2024-07-05 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8
funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
"""Sequential implementation of Recurrent Neural Network Language Model."""
from typing import Tuple
from typing import Union
@@ -37,9 +38,7 @@
        self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
        if rnn_type in ["LSTM", "GRU"]:
            rnn_class = getattr(nn, rnn_type)
            self.rnn = rnn_class(
                ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
            )
            self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
                raise ValueError("When using the tied flag, nhid must be equal to emsize")
            self.decoder.weight = self.encoder.weight
        self.rnn_type = rnn_type