From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/fsmn_vad_streaming/encoder.py |  189 +++++++++++++++++++++++++++++-----------------
 1 file changed, 119 insertions(+), 70 deletions(-)

diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index e7c0e8b..6668c5d 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -9,6 +9,7 @@
 
 from funasr.register import tables
 
+
 class LinearTransform(nn.Module):
 
     def __init__(self, input_dim, output_dim):
@@ -53,13 +54,13 @@
 class FSMNBlock(nn.Module):
 
     def __init__(
-            self,
-            input_dim: int,
-            output_dim: int,
-            lorder=None,
-            rorder=None,
-            lstride=1,
-            rstride=1,
+        self,
+        input_dim: int,
+        output_dim: int,
+        lorder=None,
+        rorder=None,
+        lstride=1,
+        rstride=1,
     ):
         super(FSMNBlock, self).__init__()
 
@@ -74,28 +75,30 @@
         self.rstride = rstride
 
         self.conv_left = nn.Conv2d(
-            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False
+        )
 
         if self.rorder > 0:
             self.conv_right = nn.Conv2d(
-                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False
+            )
         else:
             self.conv_right = None
 
     def forward(self, input: torch.Tensor, cache: torch.Tensor):
         x = torch.unsqueeze(input, 1)
         x_per = x.permute(0, 3, 2, 1)  # B D T C
-        
+
         cache = cache.to(x_per.device)
         y_left = torch.cat((cache, x_per), dim=2)
-        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
         y_left = self.conv_left(y_left)
         out = x_per + y_left
 
         if self.conv_right is not None:
             # maybe need to check
             y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
-            y_right = y_right[:, :, self.rstride:, :]
+            y_right = y_right[:, :, self.rstride :, :]
             y_right = self.conv_right(y_right)
             out += y_right
 
@@ -106,15 +109,16 @@
 
 
 class BasicBlock(nn.Module):
-    def __init__(self,
-                 linear_dim: int,
-                 proj_dim: int,
-                 lorder: int,
-                 rorder: int,
-                 lstride: int,
-                 rstride: int,
-                 stack_layer: int
-                 ):
+    def __init__(
+        self,
+        linear_dim: int,
+        proj_dim: int,
+        lorder: int,
+        rorder: int,
+        lstride: int,
+        rstride: int,
+        stack_layer: int,
+    ):
         super(BasicBlock, self).__init__()
         self.lorder = lorder
         self.rorder = rorder
@@ -128,17 +132,22 @@
 
     def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
         x1 = self.linear(input)  # B T D
-        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        cache_layer_name = "cache_layer_{}".format(self.stack_layer)
         if cache_layer_name not in cache:
-            cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+            cache[cache_layer_name] = torch.zeros(
+                x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+            )
         x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
         x3 = self.affine(x2)
         x4 = self.relu(x3)
         return x4
+
+
 class BasicBlock_export(nn.Module):
-    def __init__(self,
-                 model,
-                 ):
+    def __init__(
+        self,
+        model,
+    ):
         super(BasicBlock_export, self).__init__()
         self.linear = model.linear
         self.fsmn_block = model.fsmn_block
@@ -167,7 +176,7 @@
         return x
 
 
-'''
+"""
 FSMN net for keyword spotting
 input_dim:              input dimension
 linear_dim:             fsmn input dimensionll
@@ -176,25 +185,26 @@
 rorder:                 fsmn right order
 num_syn:                output dimension
 fsmn_layers:            no. of sequential fsmn layers
-'''
+"""
+
 
 @tables.register("encoder_classes", "FSMN")
 class FSMN(nn.Module):
     def __init__(
-            self,
-            input_dim: int,
-            input_affine_dim: int,
-            fsmn_layers: int,
-            linear_dim: int,
-            proj_dim: int,
-            lorder: int,
-            rorder: int,
-            lstride: int,
-            rstride: int,
-            output_affine_dim: int,
-            output_dim: int
+        self,
+        input_dim: int,
+        input_affine_dim: int,
+        fsmn_layers: int,
+        linear_dim: int,
+        proj_dim: int,
+        lorder: int,
+        rorder: int,
+        lstride: int,
+        rstride: int,
+        output_affine_dim: int,
+        output_dim: int,
     ):
-        super(FSMN, self).__init__()
+        super().__init__()
 
         self.input_dim = input_dim
         self.input_affine_dim = input_affine_dim
@@ -207,25 +217,21 @@
         self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
         self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)
-        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
-                                range(fsmn_layers)])
+        self.fsmn = FsmnStack(
+            *[
+                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
+                for i in range(fsmn_layers)
+            ]
+        )
         self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
         self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
         self.softmax = nn.Softmax(dim=-1)
-        
-        # export onnx or torchscripts
-        if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
-            for i, d in enumerate(self.fsmn):
-                if isinstance(d, BasicBlock):
-                    self.fsmn[i] = BasicBlock_export(d)
 
     def fuse_modules(self):
         pass
 
     def forward(
-            self,
-            input: torch.Tensor,
-            cache: Dict[str, torch.Tensor]
+        self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
@@ -244,10 +250,51 @@
 
         return x7
 
-    def export_forward(
-            self,
-            input: torch.Tensor,
-            *args,
+
+@tables.register("encoder_classes", "FSMNExport")
+class FSMNExport(nn.Module):
+    def __init__(
+        self,
+        model,
+        **kwargs,
+    ):
+        super().__init__()
+
+        # self.input_dim = input_dim
+        # self.input_affine_dim = input_affine_dim
+        # self.fsmn_layers = fsmn_layers
+        # self.linear_dim = linear_dim
+        # self.proj_dim = proj_dim
+        # self.output_affine_dim = output_affine_dim
+        # self.output_dim = output_dim
+        #
+        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        # self.relu = RectifiedLinear(linear_dim, linear_dim)
+        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+        #                         range(fsmn_layers)])
+        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        # self.softmax = nn.Softmax(dim=-1)
+        self.in_linear1 = model.in_linear1
+        self.in_linear2 = model.in_linear2
+        self.relu = model.relu
+        # self.fsmn = model.fsmn
+        self.out_linear1 = model.out_linear1
+        self.out_linear2 = model.out_linear2
+        self.softmax = model.softmax
+        self.fsmn = model.fsmn
+        for i, d in enumerate(model.fsmn):
+            if isinstance(d, BasicBlock):
+                self.fsmn[i] = BasicBlock_export(d)
+
+    def fuse_modules(self):
+        pass
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        *args,
     ):
         """
         Args:
@@ -271,7 +318,8 @@
 
         return x, out_caches
 
-'''
+
+"""
 one deep fsmn layer
 dimproj:                projection dimension, input and output dimension of memory blocks
 dimlinear:              dimension of mapping layer
@@ -279,7 +327,8 @@
 rorder:                 right order
 lstride:                left stride
 rstride:                right stride
-'''
+"""
+
 
 @tables.register("encoder_classes", "DFSMN")
 class DFSMN(nn.Module):
@@ -296,11 +345,13 @@
         self.shrink = LinearTransform(dimlinear, dimproj)
 
         self.conv_left = nn.Conv2d(
-            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
+        )
 
         if rorder > 0:
             self.conv_right = nn.Conv2d(
-                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
+            )
         else:
             self.conv_right = None
 
@@ -315,7 +366,7 @@
 
         if self.conv_right is not None:
             y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
-            y_right = y_right[:, :, self.rstride:, :]
+            y_right = y_right[:, :, self.rstride :, :]
             out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
         else:
             out = x_per + self.conv_left(y_left)
@@ -326,30 +377,28 @@
         return output
 
 
-'''
+"""
 build stacked dfsmn layers
-'''
+"""
 
 
 def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
     repeats = [
-        nn.Sequential(
-            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
-        for i in range(fsmn_layers)
+        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
     ]
 
     return nn.Sequential(*repeats)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
     print(fsmn)
 
     num_params = sum(p.numel() for p in fsmn.parameters())
-    print('the number of model params: {}'.format(num_params))
+    print("the number of model params: {}".format(num_params))
     x = torch.zeros(128, 200, 400)  # batch-size * time * dim
     y, _ = fsmn(x)  # batch-size * time * dim
-    print('input shape: {}'.format(x.shape))
-    print('output shape: {}'.format(y.shape))
+    print("input shape: {}".format(x.shape))
+    print("output shape: {}".format(y.shape))
 
     print(fsmn.to_kaldi_net())

--
Gitblit v1.9.1