From fa3e8359835107aa9ab8e3ae604ed61cad407bf8 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 14 六月 2024 10:38:02 +0800
Subject: [PATCH] update with main (#1816)

---
 funasr/models/sanm/encoder.py |  220 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 220 insertions(+), 0 deletions(-)

diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index dc30a94..b2a442b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -484,6 +484,226 @@
         return xs_pad, ilens, None
 
 
+@tables.register("encoder_classes", "SANMTPEncoder")
+class SANMTPEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            tp_blocks: int = 0,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            stochastic_depth_rate: float = 0.0,
+            input_layer: Optional[str] = "conv2d",
+            pos_enc_class=SinusoidalPositionEncoder,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 1,
+            padding_idx: int = -1,
+            kernel_size: int = 11,
+            sanm_shfit: int = 0,
+            selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                eval(pos_enc_class)(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "linear_no_pos":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                eval(pos_enc_class)(output_size, positional_dropout_rate, use_pos=False),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                eval(pos_enc_class)(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif selfattention_layer_type == "sanm":
+            encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.encoders = repeat(
+            num_blocks - 1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
+        self.tp_encoders = repeat(
+            tp_blocks,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+        self.tp_blocks = tp_blocks
+        if self.tp_blocks > 0:
+            self.tp_norm = LayerNorm(output_size)
+    def output_size(self) -> int:
+        return self._output_size
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        # forward encoder1
+        mask_shfit_chunk, mask_att_chunk_encoder = None, None
+        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        # forward encoder2
+        olens = masks.squeeze(1).sum(1)
+        mask_shfit_chunk2, mask_att_chunk_encoder2 = None, None
+        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
+            encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk2, mask_att_chunk_encoder2)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.tp_blocks > 0:
+            xs_pad = self.tp_norm(xs_pad)
+        return xs_pad, olens
+
+
 class EncoderLayerSANMExport(nn.Module):
     def __init__(
         self,

--
Gitblit v1.9.1