| | |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | class DecoderLayer(nn.Module): |
| | | """Single decoder layer module. |
| | | |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | size, |
| | | self_attn, |
| | | src_attn, |
| | | feed_forward, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | self, |
| | | size, |
| | | self_attn, |
| | | src_attn, |
| | | feed_forward, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | ): |
| | | """Construct an DecoderLayer object.""" |
| | | super(DecoderLayer, self).__init__() |
| | |
| | | tgt_q_mask = tgt_mask[:, -1:, :] |
| | | |
| | | if self.concat_after: |
| | | tgt_concat = torch.cat( |
| | | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 |
| | | ) |
| | | tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) |
| | | x = residual + self.concat_linear1(tgt_concat) |
| | | else: |
| | | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) |
| | |
| | | if self.normalize_before: |
| | | x = self.norm2(x) |
| | | if self.concat_after: |
| | | x_concat = torch.cat( |
| | | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 |
| | | ) |
| | | x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) |
| | | x = residual + self.concat_linear2(x_concat) |
| | | else: |
| | | x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) |
| | |
| | | |
| | | if cache is not None: |
| | | x = torch.cat([cache, x], dim=1) |
| | | |
| | | return x, tgt_mask, memory, memory_mask |
| | | |
| | | |
class DecoderLayerExport(nn.Module):
    """Inference-only view of a decoder layer.

    Borrows the ``self_attn``, ``src_attn``, ``feed_forward`` and
    ``norm1``/``norm2``/``norm3`` sub-modules of ``model``; they are shared,
    not copied. ``forward`` runs three residual sub-blocks. Each one
    normalizes its input first and then adds the sub-layer output back onto
    the un-normalized input. No dropout, concat path or cache handling is
    applied.
    NOTE(review): the name suggests this is used for model export
    (e.g. ONNX/TorchScript) — confirm against callers.
    """

    def __init__(self, model):
        super().__init__()
        # Register in the same order as the original attribute assignments.
        for attr in ("self_attn", "src_attn", "feed_forward", "norm1", "norm2", "norm3"):
            setattr(self, attr, getattr(model, attr))

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Apply self-attention, source attention and feed-forward sub-blocks.

        Args:
            tgt: Decoder input; it is the residual for the first sub-block.
            tgt_mask: Mask passed unchanged to ``self_attn`` and returned.
            memory: Encoder output, used as both key and value for ``src_attn``.
            memory_mask: Mask passed unchanged to ``src_attn`` and returned.
            cache: Accepted for call-signature compatibility; ignored.

        Returns:
            Tuple ``(x, tgt_mask, memory, memory_mask)``, where ``x`` is the
            layer output and the other three are the inputs, unchanged.
        """
        # Sub-block 1: self-attention over the normalized target.
        normed = self.norm1(tgt)
        hidden = tgt + self.self_attn(normed, normed, normed, tgt_mask)

        # Sub-block 2: attention from the target onto the encoder memory.
        normed = self.norm2(hidden)
        hidden = hidden + self.src_attn(normed, memory, memory, memory_mask)

        # Sub-block 3: position-wise feed-forward.
        hidden = hidden + self.feed_forward(self.norm3(hidden))

        return hidden, tgt_mask, memory, memory_mask
| | | |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | ): |
| | | super().__init__() |
| | | attention_dim = encoder_output_size |
| | |
| | | self.decoders = None |
| | | |
| | | def forward( |
| | | self, |
| | | hs_pad: torch.Tensor, |
| | | hlens: torch.Tensor, |
| | | ys_in_pad: torch.Tensor, |
| | | ys_in_lens: torch.Tensor, |
| | | self, |
| | | hs_pad: torch.Tensor, |
| | | hlens: torch.Tensor, |
| | | ys_in_pad: torch.Tensor, |
| | | ys_in_lens: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Forward decoder. |
| | | |
| | |
| | | tgt_mask = tgt_mask & m |
| | | |
| | | memory = hs_pad |
| | | memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( |
| | | memory.device |
| | | ) |
| | | memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device) |
| | | # Padding for Longformer |
| | | if memory_mask.shape[-1] != memory.shape[1]: |
| | | padlen = memory.shape[1] - memory_mask.shape[-1] |
| | | memory_mask = torch.nn.functional.pad( |
| | | memory_mask, (0, padlen), "constant", False |
| | | ) |
| | | memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False) |
| | | |
| | | x = self.embed(tgt) |
| | | x, tgt_mask, memory, memory_mask = self.decoders( |
| | | x, tgt_mask, memory, memory_mask |
| | | ) |
| | | x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask) |
| | | if self.normalize_before: |
| | | x = self.after_norm(x) |
| | | if self.output_layer is not None: |
| | |
| | | return x, olens |
| | | |
| | | def forward_one_step( |
| | | self, |
| | | tgt: torch.Tensor, |
| | | tgt_mask: torch.Tensor, |
| | | memory: torch.Tensor, |
| | | cache: List[torch.Tensor] = None, |
| | | self, |
| | | tgt: torch.Tensor, |
| | | tgt_mask: torch.Tensor, |
| | | memory: torch.Tensor, |
| | | cache: List[torch.Tensor] = None, |
| | | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: |
| | | """Forward one step. |
| | | |
| | |
| | | cache = [None] * len(self.decoders) |
| | | new_cache = [] |
| | | for c, decoder in zip(cache, self.decoders): |
| | | x, tgt_mask, memory, memory_mask = decoder( |
| | | x, tgt_mask, memory, None, cache=c |
| | | ) |
| | | x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c) |
| | | new_cache.append(x) |
| | | |
| | | if self.normalize_before: |
| | |
| | | def score(self, ys, state, x): |
| | | """Score.""" |
| | | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) |
| | | logp, state = self.forward_one_step( |
| | | ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state |
| | | ) |
| | | logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state) |
| | | return logp.squeeze(0), state |
| | | |
| | | def batch_score( |
| | | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor |
| | | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, List[Any]]: |
| | | """Score new token batch. |
| | | |
| | |
| | | else: |
| | | # transpose state of [batch, layer] into [layer, batch] |
| | | batch_state = [ |
| | | torch.stack([states[b][i] for b in range(n_batch)]) |
| | | for i in range(n_layers) |
| | | torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers) |
| | | ] |
| | | |
| | | # batch decoding |
| | |
| | | state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] |
| | | return logp, state_list |
| | | |
| | | |
| | | @tables.register("decoder_classes", "TransformerDecoder") |
| | | class TransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | ): |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | |
| | | num_blocks, |
| | | lambda lnum: DecoderLayer( |
| | | attention_dim, |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, self_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate), |
| | | MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), |
| | | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), |
| | | dropout_rate, |
| | | normalize_before, |
| | |
| | | @tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder") |
| | | class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | |
| | | use_kernel_mask=True, |
| | | use_bias=conv_usebias, |
| | | ), |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), |
| | | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), |
| | | dropout_rate, |
| | | normalize_before, |
| | |
| | | ), |
| | | ) |
| | | |
| | | |
| | | @tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder") |
| | | class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | |
| | | use_kernel_mask=True, |
| | | use_bias=conv_usebias, |
| | | ), |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), |
| | | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), |
| | | dropout_rate, |
| | | normalize_before, |
| | |
| | | @tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder") |
| | | class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | |
| | | use_kernel_mask=True, |
| | | use_bias=conv_usebias, |
| | | ), |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), |
| | | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), |
| | | dropout_rate, |
| | | normalize_before, |
| | |
| | | ), |
| | | ) |
| | | |
| | | |
| | | @tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder") |
| | | class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | self, |
| | | vocab_size: int, |
| | | encoder_output_size: int, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | self_attention_dropout_rate: float = 0.0, |
| | | src_attention_dropout_rate: float = 0.0, |
| | | input_layer: str = "embed", |
| | | use_output_layer: bool = True, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | conv_wshare: int = 4, |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | |
| | | use_kernel_mask=True, |
| | | use_bias=conv_usebias, |
| | | ), |
| | | MultiHeadedAttention( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | | ), |
| | | MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), |
| | | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), |
| | | dropout_rate, |
| | | normalize_before, |