import torch
from torch import nn

from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder

from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat

from funasr.register import tables


class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.


        x = residual + self.dropout(self.src_attn(x, memory, memory_mask))


        return x, tgt_mask, memory, memory_mask, cache

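    # NOTE: each DecoderLayerSANM call returns the 5-tuple
    # (x, tgt_mask, memory, memory_mask, cache), which is why the stacks built
    # with `repeat(...)` further below thread all five values through every layer.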
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.

        Args:

        San-m: Memory equipped self-attention for end-to-end speech recognition
        https://arxiv.org/abs/2006.01713
        """

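    # The SAN-M reference above augments self-attention with an FSMN-style memory
    # block. As a rough illustrative sketch only (not the exact implementation of
    # MultiHeadedAttentionSANMDecoder), the memory block acts like a depthwise 1-D
    # convolution over the hidden sequence whose output is added to the attention
    # output; `kernel_size` and `sanm_shfit` control its span and shift:
    #
    #     import torch
    #     d, kernel_size = 512, 11                        # placeholder sizes
    #     mem = torch.nn.Conv1d(d, d, kernel_size, groups=d, padding=(kernel_size - 1) // 2)
    #     x_mem = mem(x.transpose(1, 2)).transpose(1, 2)  # x: (batch, time, d)
    #     out = att_out + x_mem                           # `x`, `att_out` are placeholders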
    def __init__(
        self,
        vocab_size: int,

        )
        if attention_dim is None:
            attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),

            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:

                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    encoder_output_size=encoder_output_size,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,

            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim,
                    self_attention_dropout_rate,
                    kernel_size,
                    sanm_shfit=sanm_shfit,
                ),
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),

                concat_after,
            ),
        )

        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(

                attention_dim + encoder_output_size,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(
                    attention_dim + encoder_output_size,
                    linear_units,
                    dropout_rate,
                    adim=attention_dim,
                ),
                dropout_rate,
                normalize_before,
                concat_after,

        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf

    def forward(
        self,
        hs_pad: torch.Tensor,

        """
        tgt = ys_in_pad
        # tgt_mask: (batch, maxlen_out, 1)
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        # memory_mask: (batch, 1, maxlen_in)
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                # re-align the chunk-masked memory mask with the target length by
                # repeating its second-to-last row
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = self.embed(tgt)

        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens

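    # Usage sketch (names and shapes are illustrative, not normative): given encoder
    # outputs `hs_pad` (batch, maxlen_in, feat) with lengths `hlens`, and padded
    # target ids `ys_in_pad` (batch, maxlen_out) with lengths `ys_in_lens`, forward()
    # returns per-token scores `x` (batch, maxlen_out, vocab_size when the output
    # layer is enabled) and the output lengths `olens`:
    #
    #     x, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)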
    def score(
        self,
        ys,
        state,
        x,
        x_mask=None,
        pre_acoustic_embeds: torch.Tensor = None,
    ):
        """Score."""
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            x.unsqueeze(0),
            memory_mask=x_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
            cache=state,
        )
        return logp.squeeze(0), state

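    # Beam-search note: score() evaluates a single hypothesis prefix `ys` (a 1-D id
    # tensor) against the encoder memory `x`, returning log-probabilities of shape
    # (vocab_size,) for the next token together with the updated per-layer cache,
    # which the search loop passes back in as `state` on the following step.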
    def forward_one_step(
        self,
        tgt: torch.Tensor,

            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """

        x = tgt[:, -1:]
        tgt_mask = None
        x = self.embed(x)

        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)

        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:

                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)

        if self.num_blocks - self.att_layer_num >= 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num

                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)

        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = self.output_layer(y)
            y = torch.log_softmax(y, dim=-1)

        return y, new_cache

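    # Incremental decoding sketch (hypothetical names, illustration only):
    # forward_one_step() consumes the growing prefix and the per-layer cache and
    # emits log-probabilities for the next token.
    #
    #     cache = None
    #     ys = torch.tensor([[sos_id]])                  # (1, 1) running prefix
    #     for _ in range(max_len):
    #         logp, cache = decoder.forward_one_step(ys, None, memory, cache=cache)
    #         next_id = logp.argmax(-1, keepdim=True)    # greedy pick, shape (1, 1)
    #         ys = torch.cat([ys, next_id], dim=-1)
    #         if next_id.item() == eos_id:
    #             break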