commit 861147c7308b91068ffa02724fdf74ee623a909e
Author: zhifu gao
Date:   2024-04-24

diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -9,6 +9,7 @@
+from funasr.register import tables
 
 
 class LinearTransform(nn.Module):
     def __init__(self, input_dim, output_dim):
@@ -74,11 +75,13 @@
         self.rstride = rstride
         self.conv_left = nn.Conv2d(
-            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False
+        )
         if self.rorder > 0:
             self.conv_right = nn.Conv2d(
-                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False
+            )
         else:
             self.conv_right = None
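Commentary (not part of the diff): with `groups=self.dim` and a `[lorder, 1]` kernel, `conv_left` is a depthwise convolution that slides over the time axis only, mixing each feature channel with its own past frames; this is the FSMN memory block, and `conv_right` does the same over future frames when `rorder > 0`. A minimal shape check, assuming the block operates on `(B, D, T, 1)` tensors as the kernel shapes suggest; the sizes are illustrative, not the model's:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, lorder, lstride = 128, 10, 1  # illustrative values only
conv_left = nn.Conv2d(dim, dim, [lorder, 1], dilation=[lstride, 1], groups=dim, bias=False)

x = torch.randn(4, dim, 200, 1)  # B, D, T, 1
# A causal left pad of (lorder - 1) * lstride frames keeps the output length equal
# to the input length -- the same quantity the streaming cache below has to hold.
x = F.pad(x, (0, 0, (lorder - 1) * lstride, 0))
print(conv_left(x).shape)  # torch.Size([4, 128, 200, 1])
```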
@@ -106,14 +109,15 @@
 
 
 class BasicBlock(nn.Module):
-    def __init__(self,
-                 linear_dim: int,
-                 proj_dim: int,
-                 lorder: int,
-                 rorder: int,
-                 lstride: int,
-                 rstride: int,
-                 stack_layer: int
-                 ):
+    def __init__(
+        self,
+        linear_dim: int,
+        proj_dim: int,
+        lorder: int,
+        rorder: int,
+        lstride: int,
+        rstride: int,
+        stack_layer: int,
+    ):
         super(BasicBlock, self).__init__()
         self.lorder = lorder
@@ -128,15 +132,20 @@
     def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
         x1 = self.linear(input)  # B T D
-        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        cache_layer_name = "cache_layer_{}".format(self.stack_layer)
         if cache_layer_name not in cache:
-            cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+            cache[cache_layer_name] = torch.zeros(
+                x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+            )
         x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
         x3 = self.affine(x2)
         x4 = self.relu(x3)
         return x4
 
 
 class BasicBlock_export(nn.Module):
-    def __init__(self,
-                 model,
-                 ):
+    def __init__(
+        self,
+        model,
+    ):
         super(BasicBlock_export, self).__init__()
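The lazily created `"cache_layer_{i}"` entry holds exactly the `(self.lorder - 1) * self.lstride` frames of left context the memory block needs, so a caller streams audio by threading one dict through successive forward calls. A sketch of that driver loop (my illustration, not code from this commit), reusing the constructor arguments of the `__main__` smoke test at the bottom of the file; the chunk size and random input are made up:

```python
import torch
from funasr.models.fsmn_vad_streaming.encoder import FSMN

encoder = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
cache = {}  # filled lazily, one "cache_layer_{i}" entry per BasicBlock

stream = torch.randn(1, 1000, 400)  # B, T, input_dim -- stand-in for real features
for chunk in stream.split(100, dim=1):  # feed 100-frame chunks
    posteriors, cache = encoder(chunk, cache)  # cache carries left context across chunks

print(sorted(cache.keys()))  # e.g. ['cache_layer_0', ..., 'cache_layer_3']
```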
@@ -167,7 +176,7 @@
         return x
 
 
-'''
+"""
 FSMN net for keyword spotting
 input_dim:              input dimension
 linear_dim:             fsmn input dimension
@@ -176,7 +185,8 @@
 rorder:                 fsmn right order
 num_syn:                output dimension
 fsmn_layers:            no. of sequential fsmn layers
-'''
+"""
 
 
+@tables.register("encoder_classes", "FSMN")
 class FSMN(nn.Module):
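The `@tables.register("encoder_classes", "FSMN")` decorator (added together with the `tables` import at the top of this diff) is what lets a model config refer to this encoder by name. A hedged sketch of the lookup side, assuming FunASR's usual `tables.encoder_classes.get(...)` pattern; the config fragment and the mapping of keyword names to the positional values of the `__main__` smoke test are my reading of the signature, not part of this diff:

```python
from funasr.register import tables

config = {
    "encoder": "FSMN",  # the registered name
    "encoder_conf": {
        "input_dim": 400, "input_affine_dim": 140, "fsmn_layers": 4,
        "linear_dim": 250, "proj_dim": 128, "lorder": 10, "rorder": 2,
        "lstride": 1, "rstride": 1, "output_affine_dim": 140, "output_dim": 2599,
    },
}

encoder_class = tables.encoder_classes.get(config["encoder"])  # assumed lookup API
encoder = encoder_class(**config["encoder_conf"])
```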
@@ -192,7 +202,7 @@
-            lstride: int,
-            rstride: int,
-            output_affine_dim: int,
-            output_dim: int
+        lstride: int,
+        rstride: int,
+        output_affine_dim: int,
+        output_dim: int,
     ):
         super().__init__()
@@ -207,8 +217,12 @@
         self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
         self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)
-        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
-                                range(fsmn_layers)])
+        self.fsmn = FsmnStack(
+            *[
+                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
+                for i in range(fsmn_layers)
+            ]
+        )
         self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
         self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
         self.softmax = nn.Softmax(dim=-1)
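`FsmnStack` is defined earlier in this file and untouched by this commit. Judging from this call site and from `BasicBlock.forward` above (which returns only `x4` and mutates `cache` in place), it plausibly behaves like the sketch below; passing the loop index `i` as `stack_layer` is what gives each block its own `cache_layer_{i}` key in the shared dict:

```python
import torch.nn as nn

class FsmnStackSketch(nn.Sequential):
    """Assumed behavior, for illustration -- not the actual FsmnStack source."""

    def forward(self, input, cache):
        x = input
        for block in self:  # nn.Sequential iterates over its children in order
            x = block(x, cache)  # each BasicBlock reads/writes its own cache entry
        return x
```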
@@ -217,9 +231,7 @@
         pass
 
     def forward(
-            self,
-            input: torch.Tensor,
-            cache: Dict[str, torch.Tensor]
+        self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
@@ -242,7 +254,9 @@
 @tables.register("encoder_classes", "FSMNExport")
 class FSMNExport(nn.Module):
     def __init__(
-        self, model, **kwargs,
+        self,
+        model,
+        **kwargs,
     ):
         super().__init__()
 
@@ -305,7 +319,7 @@
         return x, out_caches
 
 
-'''
+"""
 one deep fsmn layer
 dimproj:                projection dimension, input and output dimension of memory blocks
 dimlinear:              dimension of mapping layer
@@ -313,7 +327,8 @@
 rorder:                 right order
 lstride:                left stride
 rstride:                right stride
-'''
+"""
 
 
+@tables.register("encoder_classes", "DFSMN")
 class DFSMN(nn.Module):
@@ -330,11 +345,13 @@
         self.shrink = LinearTransform(dimlinear, dimproj)
         self.conv_left = nn.Conv2d(
-            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
+        )
         if rorder > 0:
             self.conv_right = nn.Conv2d(
-                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
+            )
         else:
             self.conv_right = None
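`DFSMN` reuses the same depthwise memory-block construction, but here `rorder > 0` means `conv_right` attends to future frames, so the layer is non-causal with roughly `rorder * rstride` frames of lookahead. A hypothetical smoke test, with the argument order inferred from the `buildDFSMNRepeats` call below; the layer maps `dimproj` in and `dimproj` out, which is what allows direct stacking:

```python
import torch
from funasr.models.fsmn_vad_streaming.encoder import DFSMN

# dimproj, dimlinear, lorder, rorder, lstride, rstride -- order per the call below
layer = DFSMN(64, 128, 20, 1, 1, 1)
x = torch.randn(8, 100, 64)  # B, T, dimproj
print(layer(x).shape)  # expected: torch.Size([8, 100, 64])
```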
@@ -360,30 +377,28 @@
         return output
 
 
-'''
+"""
 build stacked dfsmn layers
-'''
+"""
 
 
 def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
     repeats = [
-        nn.Sequential(
-            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
-        for i in range(fsmn_layers)
+        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
     ]
     return nn.Sequential(*repeats)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
     print(fsmn)
 
     num_params = sum(p.numel() for p in fsmn.parameters())
-    print('the number of model params: {}'.format(num_params))
+    print("the number of model params: {}".format(num_params))
 
     x = torch.zeros(128, 200, 400)  # batch-size * time * dim
     y, _ = fsmn(x)  # batch-size * time * dim
 
-    print('input shape: {}'.format(x.shape))
-    print('output shape: {}'.format(y.shape))
+    print("input shape: {}".format(x.shape))
+    print("output shape: {}".format(y.shape))
 
     print(fsmn.to_kaldi_net())
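A closing usage note on `buildDFSMNRepeats` (my sketch, not from the commit): the loop variable `i` is unused, so all `fsmn_layers` layers share identical hyperparameters, and because each `DFSMN` is shape-preserving the returned `nn.Sequential` applies directly to projected features:

```python
import torch
from funasr.models.fsmn_vad_streaming.encoder import buildDFSMNRepeats

net = buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6)
x = torch.randn(8, 100, 64)  # B, T, proj_dim
print(net(x).shape)  # expected: torch.Size([8, 100, 64])
```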