| | |
| | | import random |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import random |
| | | import numpy as np |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.utils.nets_utils import to_device |
| | | from funasr.models.language_model.rnn.attentions import initial_att |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.utils.get_default_kwargs import get_default_kwargs |
| | | |
| | | |
# NOTE(review): corrupted fragment — the parameter list and most of the body of
# build_attention_list were lost (likely in a bad merge).  What remains is not
# valid Python; restore from version control rather than re-typing it.
def build_attention_list(
    # NOTE(review): parameters missing here (upstream takes eprojs, dunits and an
    # attention configuration — see the call sites at L81-84) — TODO restore.
)
    # NOTE(review): the `att_list = ...` setup, the `if num_encs == 1:` branch and
    # the initial_att(...) construction are missing above this append.
    att_list.append(att)
else:
    raise ValueError(
        "Number of encoders needs to be more than one. {}".format(num_encs)
    )
# NOTE(review): the line below is an unreachable, re-formatted duplicate of the
# raise directly above it (pre-/post-black merge artifact) — delete one copy.
raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
return att_list
| | | |
| | | |
| | | class RNNDecoder(AbsDecoder): |
| | | @tables.register("decoder_classes", "rnn_decoder") |
| | | class RNNDecoder(nn.Module): |
    def __init__(
        self,
        vocab_size: int,
        # NOTE(review): parameters missing here (corrupted span).  The later body
        # references rnn_type, hidden_size, eprojs, dropout and self.dlayers, so the
        # deleted parameters presumably included encoder_output_size, rnn_type,
        # num_layers, hidden_size, sampling_probability and dropout — TODO restore
        # from version control.
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        # NOTE(review): `att_conf` is declared twice below — a duplicate keyword
        # parameter is a SyntaxError; exactly one of the two defaults must be kept.
        att_conf: dict = get_default_kwargs(build_attention_list),
        att_conf: dict = None,
    ):
        # FIXME(kamo): The parts of num_spk should be refactored more more more
        if rnn_type not in {"lstm", "gru"}:
            # NOTE(review): the suite of this `if` (presumably a raise for an
            # unsupported rnn_type), the super().__init__() call and all attribute
            # assignments (self.dtype, self.dunits, self.dlayers, embedding, ...)
            # are missing here — TODO restore.

        # First decoder layer takes the embedded token concatenated with the
        # attention context (hidden_size + eprojs features).
        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        self.decoder += [
            # NOTE(review): the conditional expression appears twice — once bare and
            # once parenthesized (pre-/post-black merge duplicate).  As written the
            # juxtaposition makes the parenthesized copy a *call* on the GRUCell;
            # exactly one copy must be kept.
            torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
            if self.dtype == "lstm"
            else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
            (
                torch.nn.LSTMCell(hidden_size + eprojs, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size + eprojs, hidden_size)
            )
        ]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        # Remaining layers are hidden_size -> hidden_size cells.
        for _ in range(1, self.dlayers):
            self.decoder += [
                # NOTE(review): same duplicated-expression merge artifact as above.
                torch.nn.LSTMCell(hidden_size, hidden_size)
                if self.dtype == "lstm"
                else torch.nn.GRUCell(hidden_size, hidden_size)
                (
                    torch.nn.LSTMCell(hidden_size, hidden_size)
                    if self.dtype == "lstm"
                    else torch.nn.GRUCell(hidden_size, hidden_size)
                )
            ]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]
            # NOTE: dropout is applied only for the vertical connections

        # NOTE(review): orphan `else:` — its `if` (presumably
        # `if context_residual:` with a wider Linear input) was deleted.
        else:
            self.output = torch.nn.Linear(hidden_size, vocab_size)

        # NOTE(review): self.att_list is assigned twice — the single-line version is
        # a re-formatted duplicate of the multi-line call; keep only one.
        self.att_list = build_attention_list(
            eprojs=eprojs, dunits=hidden_size, **att_conf
        )
        self.att_list = build_attention_list(eprojs=eprojs, dunits=hidden_size, **att_conf)
| | | |
| | | def zero_state(self, hs_pad): |
| | | return hs_pad.new_zeros(hs_pad.size(0), self.dunits) |
| | |
        # NOTE(review): fragment of `rnn_forward` — the `def rnn_forward(...)` line
        # and the matching `if` branch (presumably the LSTM path updating both
        # z_list and c_list) were deleted above this `else:`.  TODO restore.
        else:
            # GRU path: only hidden states, no cell states.
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for i in range(1, self.dlayers):
                # Vertical dropout between stacked layers.
                z_list[i] = self.decoder[i](
                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
                )
                # NOTE(review): re-formatted duplicate of the statement directly
                # above (merge artifact) — it overwrites z_list[i] and draws dropout
                # twice; exactly one copy must be kept.
                z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i])
        return z_list, c_list
| | | |
    def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0):
        # NOTE(review): corrupted — the code after this signature reads `state[...]`
        # which `forward`'s parameters do not provide; it looks like the body of a
        # different method (the step-wise beam-search scorer) was spliced under
        # `forward`'s signature, and everything between the signature and the line
        # below is missing.  TODO restore both methods from version control.
            state["a_prev"][self.num_encs],
        )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        # NOTE(review): re-formatted duplicate of the rnn_forward call directly
        # above (merge artifact) — keep exactly one copy.
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"])
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
            # NOTE(review): re-formatted duplicate of the assignment above (merge
            # artifact) — keep exactly one copy.
            logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        # NOTE(review): the method is truncated here — the rest is outside this view.