| | |
| | | from funasr.models.transformer.utils.repeat import repeat |
| | | from funasr.register import tables |
| | | |
| | | |
| | | class EncoderLayer(nn.Module): |
| | | """Encoder layer module. |
| | | |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | size, |
| | | self_attn, |
| | | feed_forward, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | stochastic_depth_rate=0.0, |
| | | self, |
| | | size, |
| | | self_attn, |
| | | feed_forward, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | stochastic_depth_rate=0.0, |
| | | ): |
| | | """Construct an EncoderLayer object.""" |
| | | super(EncoderLayer, self).__init__() |
| | |
| | | x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) |
| | | x = residual + stoch_layer_coeff * self.concat_linear(x_concat) |
| | | else: |
| | | x = residual + stoch_layer_coeff * self.dropout( |
| | | self.self_attn(x_q, x, x, mask) |
| | | ) |
| | | x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask)) |
| | | if not self.normalize_before: |
| | | x = self.norm1(x) |
| | | |
| | |
| | | x = torch.cat([cache, x], dim=1) |
| | | |
| | | return x, mask |
| | | |
| | | |
| | | @tables.register("encoder_classes", "TransformerTextEncoder") |
| | | class TransformerTextEncoder(nn.Module): |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | ): |
| | | super().__init__() |
| | | self._output_size = output_size |
| | |
| | | num_blocks, |
| | | lambda lnum: EncoderLayer( |
| | | output_size, |
| | | MultiHeadedAttention( |
| | | attention_heads, output_size, attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | |
| | | return self._output_size |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | """Embed positions in tensor. |
| | | |
| | |
| | | |
| | | olens = masks.squeeze(1).sum(1) |
| | | return xs_pad, olens, None |
| | | |
| | | |
| | | |
| | | |
| | | @tables.register("encoder_classes", "FusionSANEncoder") |
| | |
| | | |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | size, |
| | | attention_heads, |
| | | attention_dim, |
| | | linear_units, |
| | | self_attention_dropout_rate, |
| | | src_attention_dropout_rate, |
| | | positional_dropout_rate, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | self, |
| | | size, |
| | | attention_heads, |
| | | attention_dim, |
| | | linear_units, |
| | | self_attention_dropout_rate, |
| | | src_attention_dropout_rate, |
| | | positional_dropout_rate, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | ): |
| | | """Construct a SelfSrcAttention object.""" |
| | | super(SelfSrcAttention, self).__init__() |
| | | self.size = size |
| | | self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) |
| | | self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate) |
| | | self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate) |
| | | self.self_attn = MultiHeadedAttention( |
| | | attention_heads, attention_dim, self_attention_dropout_rate |
| | | ) |
| | | self.src_attn = MultiHeadedAttentionReturnWeight( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ) |
| | | self.feed_forward = PositionwiseFeedForward( |
| | | attention_dim, linear_units, positional_dropout_rate |
| | | ) |
| | | self.norm1 = LayerNorm(size) |
| | | self.norm2 = LayerNorm(size) |
| | | self.norm3 = LayerNorm(size) |
| | |
| | | tgt_q_mask = tgt_mask[:, -1:, :] |
| | | |
| | | if self.concat_after: |
| | | tgt_concat = torch.cat( |
| | | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 |
| | | ) |
| | | tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) |
| | | x = residual + self.concat_linear1(tgt_concat) |
| | | else: |
| | | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) |
| | |
| | | if self.normalize_before: |
| | | x = self.norm2(x) |
| | | if self.concat_after: |
| | | x_concat = torch.cat( |
| | | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 |
| | | ) |
| | | x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) |
| | | x = residual + self.concat_linear2(x_concat) |
| | | else: |
| | | x, score = self.src_attn(x, memory, memory, memory_mask) |
| | |
| | | |
| | | @tables.register("encoder_classes", "ConvBiasPredictor") |
| | | class ConvPredictor(nn.Module): |
| | | def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048): |
| | | def __init__( |
| | | self, |
| | | size=256, |
| | | l_order=3, |
| | | r_order=3, |
| | | attention_heads=4, |
| | | attention_dropout_rate=0.1, |
| | | linear_units=2048, |
| | | ): |
| | | super().__init__() |
| | | self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) |
| | | self.norm1 = LayerNorm(size) |
| | |
| | | self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) |
| | | self.output_linear = nn.Linear(size, 1) |
| | | |
| | | |
| | | def forward(self, text_enc, asr_enc): |
| | | # stage1 cross-attention |
| | | residual = text_enc |
| | | text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) |
| | | |
| | | |
| | | # stage2 FFN |
| | | residual = text_enc |
| | | text_enc = self.norm1(text_enc) |
| | | text_enc = residual + self.feed_forward(text_enc) |
| | | |
| | | |
| | | # stage3 Conv predictor |
| | | text_enc = self.norm2(text_enc) |
| | | context = text_enc.transpose(1, 2) |
| | |
| | | output = output.transpose(1, 2) |
| | | output = torch.relu(output) |
| | | output = self.output_linear(output) |
| | | if output.dim()==3: |
| | | output = output.squeeze(2) |
| | | if output.dim() == 3: |
| | | output = output.squeeze(2) |
| | | return output |