import torch
import torch.nn as nn

from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables


class EncoderLayer(nn.Module):
    """Encoder layer module."""

    def forward(self, x, mask, cache=None):
        # ... (pre-norm, residual, and cache handling elided; this part
        # sets up residual, x_q, and stoch_layer_coeff) ...
        if self.concat_after:
            # Concatenate the attention output with its input along the
            # feature axis, then project the doubled width back down.
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)
        # ... (feed-forward sublayer elided) ...
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask

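
# A minimal, self-contained sketch of the ``concat_after`` path in
# EncoderLayer.forward above. The tensor shapes and the projection layer
# below are illustrative assumptions, not values taken from this file.
def _concat_after_sketch():
    batch, time, dim = 2, 5, 8
    x = torch.randn(batch, time, dim)
    attn_out = torch.randn(batch, time, dim)  # stands in for self_attn(x_q, x, x, mask)
    concat_linear = nn.Linear(dim * 2, dim)  # maps the doubled width back to dim
    x_concat = torch.cat((x, attn_out), dim=-1)  # (batch, time, 2 * dim)
    return x + concat_linear(x_concat)  # residual add, as in the encoder layer
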
@tables.register("encoder_classes", "TransformerEncoder")
class TransformerEncoder(nn.Module):
        # ... (__init__ signature, embedding, and positionwise_layer setup
        # elided) ...
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
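
        # What ``repeat`` builds here (a sketch of the helper's behavior,
        # assuming the FunASR/ESPnet ``repeat`` imported above): call the
        # factory once per block index and chain the layers sequentially,
        # roughly MultiSequential(*[fn(i) for i in range(num_blocks)]).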

        # ... (forward definition and the pass through self.encoders
        # elided) ...
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
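
# A minimal usage sketch (the hyperparameters and 80-dim feature input are
# illustrative assumptions, not defaults read from this file):
#
#   encoder = TransformerEncoder(
#       input_size=80,
#       output_size=256,
#       attention_heads=4,
#       num_blocks=6,
#   )
#   xs_pad = torch.randn(2, 100, 80)  # (batch, time, feature)
#   ilens = torch.tensor([100, 80])   # valid frames per utterance
#   xs_out, olens, _ = encoder(xs_pad, ilens)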