| | |
| | | x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() |
| | | |
| | | return x |
| | | |
| | | |
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """FSMN memory block used in place of self-attention in the SANM decoder.

    Despite its name, this layer has no attention heads. It pads the sequence
    in time, applies a depthwise (``groups=n_feat``) 1-D convolution without
    bias, adds the (masked) input back as a residual, applies dropout and
    optionally masks the result.

    Args:
        n_feat (int): Number of features (channels) of the input.
        dropout_rate (float): Dropout rate applied to the output.
        kernel_size (int): Temporal kernel size of the depthwise convolution.
        sanm_shfit (int): Extra frames of left padding on top of the default
            ``(kernel_size - 1) // 2``. Each extra frame is taken away from the
            right padding, shifting the receptive field towards the past.
            NOTE(review): values larger than
            ``kernel_size - 1 - (kernel_size - 1) // 2`` make the right padding
            negative; presumably a misconfiguration — confirm before using.

    """

    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANMDecoder object."""
        super().__init__()

        self.dropout = nn.Dropout(p=dropout_rate)

        # Depthwise convolution: one temporal filter per feature channel.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Left + right padding always sums to kernel_size - 1, so the
        # convolution output has the same length as the unpadded input.
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size

    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
        """Apply the FSMN memory block.

        Args:
            inputs: Input tensor of shape (batch, time, n_feat).
            mask: Optional mask, reshaped to (batch, -1, 1). It is multiplied
                into ``inputs`` before the convolution and into the output.
            cache: Optional padded history returned by a previous call, of
                shape (batch, n_feat, frames). When given, the new frames are
                appended to it instead of being zero-padded.
            mask_shfit_chunk: Optional extra mask multiplied into ``mask``;
                only used when ``mask`` is not None.

        Returns:
            tuple: ``(output, cache)``.
            ``output`` has shape (batch, time_out, n_feat), where
            ``time_out == time`` unless a supplied ``cache`` is too short to
            give the convolution a full window for every new frame.
            ``cache`` is the padded convolution input. It is ``None`` when no
            cache was passed in and the module is in training mode.
        """
        b, t, _ = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        # Conv1d expects (batch, channels, time).
        x = inputs.transpose(1, 2)
        if cache is None:
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # Drop the oldest cached frame, append the new frames, and keep at
            # most kernel_size + t - 1 frames so the convolution emits at most
            # t outputs.
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1) :]
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        t_out = x.size(1)
        if t_out != t:
            # A short cache produces fewer outputs. The last output window ends
            # on the last input frame, so the outputs line up with the last
            # t_out input frames. Keep the time axis so batch > 1 does not
            # broadcast into a wrong (b, b, d) shape.
            inputs = inputs[:, -t_out:, :]
            if mask is not None:
                mask = mask[:, -t_out:, :]

        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
| | | |
| | | |
class ResidualAttentionBlockFSMN(nn.Module):
    """Pre-norm decoder block: FSMN memory, optional cross-attention, MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of the running hidden
    state, and its output is added back as a residual.

    Args:
        n_state (int): Model (hidden) dimension.
        n_head (int): Number of heads of the cross-attention.
        cross_attention (bool): Whether to build the cross-attention sub-layer.
        **kwargs: ``self_attention_dropout_rate`` (no default here; passed to
            ``nn.Dropout``, so it appears to be required), ``kernel_size``
            (default 20) and ``sanm_shfit`` (default 10) configure the FSMN
            block. Other keys are ignored.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
        super().__init__()

        self.attn = MultiHeadedAttentionSANMDecoder(
            n_state,
            kwargs.get("self_attention_dropout_rate"),
            kwargs.get("kernel_size", 20),
            kwargs.get("sanm_shfit", 10),
        )
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Apply the block to ``x``.

        Args:
            x: Decoder hidden states; presumably (batch, time, n_state).
            xa: Encoder output attended to by the cross-attention.
            mask: Accepted for interface compatibility only. The FSMN block
                takes a multiplicative padding mask, not an additive attention
                mask, so this is not forwarded to it.
            kv_cache: Forwarded to the cross-attention only; the FSMN block has
                no kv_cache argument.
            **kwargs: ``memory_mask`` is the padding mask for ``xa`` and is
                passed to the cross-attention. ``is_pad_memory_mask`` tells the
                cross-attention that this mask is a padding mask. Other keys,
                such as ``is_pad_mask``, are ignored.

        Returns:
            Tensor: updated hidden states.
        """
        memory_mask = kwargs.get("memory_mask", None)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)

        # MultiHeadedAttentionSANMDecoder.forward(inputs, mask, cache=None, ...)
        # accepts neither kv_cache nor is_pad_mask; passing them raised TypeError.
        x = x + self.attn(self.attn_ln(x), mask=None)[0]

        if self.cross_attn is not None:
            # NOTE(review): memory_mask is assumed to be a (batch, 1, T_mem)
            # boolean mask with True on valid frames — confirm against the
            # cross-attention's is_pad_mask handling.
            x = (
                x
                + self.cross_attn(
                    self.cross_attn_ln(x),
                    xa,
                    mask=memory_mask,
                    kv_cache=kv_cache,
                    is_pad_mask=is_pad_memory_mask,
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
        return x
| | | |
| | | |
@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
class SenseVoiceDecoderFSMN(nn.Module):
    """SenseVoice text decoder built from FSMN-based residual blocks.

    The output projection is tied to the token embedding: logits are the final
    hidden states multiplied by the transposed embedding matrix.

    Args:
        n_vocab (int): Vocabulary size.
        n_ctx (int): Maximum number of target positions.
        n_state (int): Model (hidden) dimension.
        n_head (int): Number of cross-attention heads per block.
        n_layer (int): Number of decoder blocks.
        **kwargs: Forwarded to every block. ``use_padmask`` (default True)
            controls whether an encoder padding mask is built from ``hlens``.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # Allocated uninitialised; presumably filled from a checkpoint.
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockFSMN(
                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
                )
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # Non-persistent, so it is not stored in the state_dict.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

        self.use_padmask = kwargs.get("use_padmask", True)

    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Run the decoder and return token logits.

        Args:
            x: Target token ids, int64 (batch, maxlen_out). Entries equal to
                -1 are embedded as token 0. The caller's tensor is not modified.
            xa: Encoded memory, (batch, maxlen_in, feat).
            kv_cache: Optional dict. Here it is only used to derive the
                positional offset from ``shape[1]`` of its first value; it is
                not passed to the blocks.
            **kwargs: ``hlens`` (batch,) encoder lengths, used to build the
                memory padding mask when ``use_padmask`` is enabled.

        Returns:
            torch.Tensor: float32 logits of shape (batch, maxlen_out, n_vocab).
        """
        hlens = kwargs.get("hlens", None)

        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0

        # Map -1 padding ids to 0 without mutating the caller's tensor (the
        # previous in-place assignment rewrote the caller's token ids).
        tgt = x.masked_fill(x == -1, 0)
        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]

        hidden = tgt.to(xa.dtype)

        if self.use_padmask and hlens is not None:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(xa.device)
        else:
            memory_mask = None

        for block in self.blocks:
            hidden = block(
                hidden,
                xa,
                mask=self.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )

        hidden = self.ln(hidden)
        logits = hidden @ torch.transpose(self.token_embedding.weight.to(hidden.dtype), 0, 1)
        return logits.float()