import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import initial_att
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.utils.get_default_kwargs import get_default_kwargs


def build_attention_list(
    eprojs: int,
    dunits: int,
    atype: str = "location",
    num_att: int = 1,
    num_encs: int = 1,
    aheads: int = 4,
    adim: int = 320,
    awin: int = 5,
    aconv_chans: int = 10,
    aconv_filts: int = 100,
):
    # Build one attention module per decoder attention head (single encoder),
    # or one per encoder stream when num_encs > 1; initial_att dispatches on
    # ``atype`` (e.g. "dot", "add", "location").
    att_list = nn.ModuleList()
    for _ in range(num_att if num_encs == 1 else num_encs):
        att_list.append(
            initial_att(
                atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts
            )
        )
    return att_list
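# Example (hypothetical values): a single location-aware attention over
# 320-dim encoder projections driving a 320-unit decoder:
#
#     atts = build_attention_list(eprojs=320, dunits=320, atype="location")
#     assert len(atts) == 1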
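# Attention-based RNN decoder: embeds the previous tokens, attends over the
# encoder states with the modules from build_attention_list, and predicts
# the next token with a stacked LSTM/GRU.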
class RNNDecoder(AbsDecoder):
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        sampling_probability: float = 0.0,
        dropout: float = 0.0,
        context_residual: bool = False,
        replace_sos: bool = False,
        num_encs: int = 1,
        att_conf: dict = get_default_kwargs(build_attention_list),
    ):
        # FIXME(kamo): The parts of num_spk should be refactored more
        if rnn_type not in {"lstm", "gru"}: