From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix hotwords bug

---
 funasr/models/sanm/encoder.py |  404 +++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 387 insertions(+), 17 deletions(-)

diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index cb4e21a..b2a442b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -12,7 +17,10 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+    SinusoidalPositionEncoder,
+    StreamSinusoidalPositionEncoder,
+)
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
 from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -31,6 +39,7 @@
 from funasr.models.ctc.ctc import CTC
 
 from funasr.register import tables
+
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -91,7 +100,18 @@
             x = self.norm1(x)
 
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (
+                    x,
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    ),
+                ),
+                dim=-1,
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
@@ -99,11 +119,21 @@
         else:
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
             else:
                 x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
         if not self.normalize_before:
             x = self.norm1(x)
@@ -153,13 +183,13 @@
 
         return x, cache
 
+
 @tables.register("encoder_classes", "SANMEncoder")
 class SANMEncoder(nn.Module):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
-
     """
 
     def __init__(
@@ -181,8 +211,8 @@
         padding_idx: int = -1,
         interctc_layer_idx: List[int] = [],
         interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
         lora_list: List[str] = None,
         lora_rank: int = 8,
         lora_alpha: int = 16,
@@ -302,7 +332,7 @@
         )
 
         self.encoders = repeat(
-            num_blocks-1,
+            num_blocks - 1,
             lambda lnum: EncoderLayerSANM(
                 output_size,
                 output_size,
@@ -345,7 +375,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = xs_pad * self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (
@@ -404,15 +434,16 @@
             return feats
         cache["feats"] = to_device(cache["feats"], device=feats.device)
         overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
         return overlap_feats
 
-    def forward_chunk(self,
-                      xs_pad: torch.Tensor,
-                      ilens: torch.Tensor,
-                      cache: dict = None,
-                      ctc: CTC = None,
-                      ):
+    def forward_chunk(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        cache: dict = None,
+        ctc: CTC = None,
+    ):
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
@@ -452,3 +483,342 @@
             return (xs_pad, intermediate_outs), None, None
         return xs_pad, ilens, None
 
+
+@tables.register("encoder_classes", "SANMTPEncoder")
+class SANMTPEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01712
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        tp_blocks: int = 0,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        stochastic_depth_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "linear_no_pos":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                pos_enc_class(output_size, positional_dropout_rate, use_pos=False),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+            # plain MHA has no separate first-layer variant; reuse the same
+            # args for encoders0 (assumes input_size == output_size here)
+            encoder_selfattn_layer_args0 = encoder_selfattn_layer_args
+        elif selfattention_layer_type == "sanm":
+            encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
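+        # Two argument tuples above: the first SANM layer maps the raw
+        # feature dimension (input_size) to output_size, while every later
+        # layer operates at output_size.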
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.encoders = repeat(
+            num_blocks - 1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
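+        # Extra "tp" blocks (presumably timestamp-prediction layers) run
+        # after the main stack and get their own final LayerNorm (tp_norm).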
+        self.tp_encoders = repeat(
+            tp_blocks,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+        self.tp_blocks = tp_blocks
+        if self.tp_blocks > 0:
+            self.tp_norm = LayerNorm(output_size)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Embed positions in tensor.
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        # forward encoder1: encoders0 (input_size -> output_size), then the main stack
+        mask_shfit_chunk, mask_att_chunk_encoder = None, None
+        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        # forward encoder2: run the extra tp blocks on the normalized output
+        olens = masks.squeeze(1).sum(1)
+        mask_shfit_chunk2, mask_att_chunk_encoder2 = None, None
+        for encoder_layer in self.tp_encoders:
+            encoder_outs = encoder_layer(
+                xs_pad, masks, None, mask_shfit_chunk2, mask_att_chunk_encoder2
+            )
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        if self.tp_blocks > 0:
+            xs_pad = self.tp_norm(xs_pad)
+        return xs_pad, olens
+
+
+class EncoderLayerSANMExport(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.in_size = model.in_size
+        self.size = model.size
+
+    def forward(self, x, mask):
+
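+        # Pre-norm only: this export layer assumes normalize_before=True,
+        # and skips the residual connection whenever the layer changes
+        # width (in_size != size).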
+        residual = x
+        x = self.norm1(x)
+        x = self.self_attn(x, mask)
+        if self.in_size == self.size:
+            x = x + residual
+        residual = x
+        x = self.norm2(x)
+        x = self.feed_forward(x)
+        x = x + residual
+
+        return x, mask
+
+
+@tables.register("encoder_classes", "SANMEncoderChunkOptExport")
+@tables.register("encoder_classes", "SANMEncoderExport")
+class SANMEncoderExport(nn.Module):
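+    """Export-friendly wrapper around SANMEncoder: replaces each SANM
+    layer and its attention with static-shape variants for ONNX export."""
+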
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name="encoder",
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
+            self.embed = None
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+
+        from funasr.utils.torch_function import sequence_mask
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
+
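+        # Replace every SANM attention with its export twin, then wrap each
+        # layer so that forward() consumes the precomputed mask pair.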
+        if hasattr(model, "encoders0"):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+                self.model.encoders0[i] = EncoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+            self.model.encoders[i] = EncoderLayerSANMExport(d)
+
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+    def prepare_mask(self, mask):
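+        # Derive two masks from the pad mask: a multiplicative (B, T, 1)
+        # mask for the FSMN memory branch and an additive attention bias
+        # that pushes padded positions to -1e4.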
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False):
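+        # Offline inputs are scaled by sqrt(d_model) here; in streaming
+        # (online) mode the input is assumed to be scaled already.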
+        if not online:
+            speech = speech * self._output_size**0.5
+        mask = self.make_pad_mask(speech_lengths)
+        mask = self.prepare_mask(mask)
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+
+        encoder_outs = self.model.encoders0(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        encoder_outs = self.model.encoders(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        xs_pad = self.model.after_norm(xs_pad)
+
+        return xs_pad, speech_lengths
+
+    def get_output_size(self):
+        return self.model.encoders[0].size
+
+    def get_dummy_inputs(self):
+        feats = torch.randn(1, 100, self.feats_dim)
+        return feats
+
+    def get_input_names(self):
+        return ["feats"]
+
+    def get_output_names(self):
+        return ["encoder_out", "encoder_out_lens", "predictor_weight"]
+
+    def get_dynamic_axes(self):
+        return {
+            "feats": {1: "feats_length"},
+            "encoder_out": {1: "enc_out_length"},
+            "predictor_weight": {1: "pre_out_length"},
+        }

--
Gitblit v1.9.1