zhifu gao
2024-09-25 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d
funasr/models/fsmn_vad_streaming/encoder.py
@@ -85,13 +85,17 @@
        else:
            self.conv_right = None
    def forward(self, input: torch.Tensor, cache: torch.Tensor):
    def forward(self, input: torch.Tensor, cache: torch.Tensor = None):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C
        cache = cache.to(x_per.device)
        y_left = torch.cat((cache, x_per), dim=2)
        cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
        if cache is not None:
            cache = cache.to(x_per.device)
            y_left = torch.cat((cache, x_per), dim=2)
            cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
        else:
            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        y_left = self.conv_left(y_left)
        out = x_per + y_left
@@ -130,14 +134,18 @@
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None):
        x1 = self.linear(input)  # B T D
        cache_layer_name = "cache_layer_{}".format(self.stack_layer)
        if cache_layer_name not in cache:
            cache[cache_layer_name] = torch.zeros(
                x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
            )
        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
        if cache is not None:
            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
            if cache_layer_name not in cache:
                cache[cache_layer_name] = torch.zeros(
                    x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
                )
            x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
        else:
            x2, _ = self.fsmn_block(x1, None)
        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4
@@ -203,6 +211,7 @@
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
        use_softmax: bool = True,
    ):
        super().__init__()
@@ -225,13 +234,21 @@
        )
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.use_softmax = use_softmax
        if self.use_softmax:
            self.softmax = nn.Softmax(dim=-1)
    def fuse_modules(self):
        pass
    def output_size(self) -> int:
        return self.output_dim
    def forward(
        self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
        self,
        input: torch.Tensor,
        cache: Dict[str, torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
@@ -246,9 +263,12 @@
        x4 = self.fsmn(x3, cache)  # self.cache will update automatically in self.fsmn
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7 = self.softmax(x6)
        return x7
        if self.use_softmax:
            x7 = self.softmax(x6)
            return x7
        return x6
@tables.register("encoder_classes", "FSMNExport")
@@ -276,6 +296,7 @@
        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)
        self.in_linear1 = model.in_linear1
        self.in_linear2 = model.in_linear2
        self.relu = model.relu
@@ -317,88 +338,3 @@
        x = self.softmax(x)
        return x, out_caches
"""
one deep fsmn layer
dimproj:                projection dimension, input and output dimension of memory blocks
dimlinear:              dimension of mapping layer
lorder:                 left order
rorder:                 right order
lstride:                left stride
rstride:                right stride
"""
@tables.register("encoder_classes", "DFSMN")
class DFSMN(nn.Module):
    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
        super(DFSMN, self).__init__()
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)
        self.conv_left = nn.Conv2d(
            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
        )
        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
            )
        else:
            self.conv_right = None
    def forward(self, input):
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)
        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)
        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride :, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)
        out1 = out.permute(0, 3, 2, 1)
        output = input + out1.squeeze(1)
        return output
"""
build stacked dfsmn layers
"""
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
    repeats = [
        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
    ]
    return nn.Sequential(*repeats)
if __name__ == "__main__":
    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)
    num_params = sum(p.numel() for p in fsmn.parameters())
    print("the number of model params: {}".format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x)  # batch-size * time * dim
    print("input shape: {}".format(x.shape))
    print("output shape: {}".format(y.shape))
    print(fsmn.to_kaldi_net())