From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/models/encoder/sanm_encoder.py |  395 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 385 insertions(+), 10 deletions(-)

diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 4c4bd7c..c15343e 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -6,12 +6,13 @@
 import logging
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
 import numpy as np
+from funasr.torch_utils.device_funcs import to_device
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
+from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
 from funasr.modules.layer_norm import LayerNorm
 from funasr.modules.multi_layer_conv import Conv1dLinear
 from funasr.modules.multi_layer_conv import MultiLayeredConv1d
@@ -25,9 +26,10 @@
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
-
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -112,12 +114,48 @@
         if not self.normalize_before:
             x = self.norm2(x)
 
-
         return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
 
 class SANMEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -144,11 +182,14 @@
         interctc_use_conditioning: bool = False,
         kernel_size : int = 11,
         sanm_shfit : int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         selfattention_layer_type: str = "sanm",
         tf2torch_tensor_name_prefix_torch: str = "encoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 
@@ -180,6 +221,8 @@
                 self.embed = torch.nn.Linear(input_size, output_size)
         elif input_layer == "pe":
             self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
         else:
             raise ValueError("unknown input_layer: " + input_layer)
         self.normalize_before = normalize_before
@@ -226,6 +269,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
 
             encoder_selfattn_layer_args = (
@@ -235,6 +282,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
         self.encoders0 = repeat(
             1,
@@ -293,7 +344,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size()**0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (
@@ -346,6 +397,59 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
+
+    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+        if len(cache) == 0:
+            return feats
+        cache["feats"] = to_device(cache["feats"], device=feats.device)
+        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        return overlap_feats
+
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ctc: CTC = None,
+                      ):
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            xs_pad = self.embed(xs_pad, cache)
+        if cache["tail_chunk"]:
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, None, None, None, None)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), None, None
+        return xs_pad, ilens, None
 
     def gen_tf2torch_map_dict(self):
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
@@ -507,7 +611,7 @@
 
 class SANMEncoderChunkOpt(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -543,7 +647,6 @@
             tf2torch_tensor_name_prefix_torch: str = "encoder",
             tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 
@@ -575,6 +678,8 @@
                 self.embed = torch.nn.Linear(input_size, output_size)
         elif input_layer == "pe":
             self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
         else:
             raise ValueError("unknown input_layer: " + input_layer)
         self.normalize_before = normalize_before
@@ -760,6 +865,49 @@
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
 
+    def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+        if len(cache) == 0:
+            return feats
+        cache["feats"] = to_device(cache["feats"], device=feats.device)
+        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        return overlap_feats
+
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ):
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            xs_pad = self.embed(xs_pad, cache)
+        if cache["tail_chunk"]:
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        if cache["opt"] is None:
+            cache_layer_num = len(self.encoders0) + len(self.encoders)
+            new_cache = [None] * cache_layer_num
+        else:
+            new_cache = cache["opt"]
+
+        for layer_idx, encoder_layer in enumerate(self.encoders0):
+            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
+            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
+            cache["opt"] = new_cache
+
+        return xs_pad, ilens, None
+
     def gen_tf2torch_map_dict(self):
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
         tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
@@ -916,3 +1064,230 @@
                                                                                       var_dict_tf[name_tf].shape))
     
         return var_dict_torch_update
+
+
+class SANMVadEncoder(AbsEncoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+                if layer_idx + 1 == len(self.encoders):
+                    # This is last layer.
+                    coner_mask = torch.ones(masks.size(0),
+                                            masks.size(-1),
+                                            masks.size(-1),
+                                            device=xs_pad.device,
+                                            dtype=torch.bool)
+                    for word_index, length in enumerate(ilens):
+                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                                vad_indexes[word_index],
+                                                                device=xs_pad.device)
+                    layer_mask = masks & coner_mask
+                else:
+                    layer_mask = no_future_masks
+                mask_tup1 = [masks, layer_mask]
+                encoder_outs = encoder_layer(xs_pad, mask_tup1)
+                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None

--
Gitblit v1.9.1