liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/language_model/rnn/encoders.py
@@ -7,7 +7,7 @@
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from funasr.metrics import get_vgg2l_odim
from funasr.metrics.common import get_vgg2l_odim
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
@@ -34,9 +34,7 @@
                inputdim = hdim
            RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
            rnn = RNN(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            rnn = RNN(inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True)
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
@@ -72,9 +70,7 @@
            rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(
                xs_pack, hx=None if prev_state is None else prev_state[layer]
            )
            ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
@@ -155,9 +151,7 @@
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(
            self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2)))
        )
        projected = torch.tanh(self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2))))
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim
@@ -225,9 +219,7 @@
        else:
            ilens = np.array(ilens, dtype=np.float32)
        ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        ilens = np.array(
            np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()
        ilens = np.array(np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
        # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
        xs_pad = xs_pad.transpose(1, 2)
@@ -250,9 +242,7 @@
    :param int in_channel: number of input channels
    """
    def __init__(
            self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1
    ):
    def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
@@ -367,6 +357,4 @@
            enc_list.append(enc)
        return enc_list
    else:
        raise ValueError(
            "Number of encoders needs to be more than one. {}".format(num_encs)
        )
        raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))