import os
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.register import tables


class LinearTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, input):
        return self.linear(input)
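

# The blocks below also use AffineTransform and RectifiedLinear wrappers that
# are not present in this excerpt; minimal reconstructions from their usage
# (AffineTransform is a biased nn.Linear, RectifiedLinear a plain ReLU):


class AffineTransform(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        return self.linear(input)


class RectifiedLinear(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()

    def forward(self, input):
        return self.relu(input)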


class FSMNBlock(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        lorder=None,
        rorder=None,
        lstride=1,
        rstride=1,
    ):
        super(FSMNBlock, self).__init__()

        self.dim = input_dim
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.conv_left = nn.Conv2d(
            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False
        )

        if self.rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False
            )
        else:
            self.conv_right = None

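    # Streaming notes: `cache` carries the last (lorder - 1) * lstride frames
    # of the previous chunk so the left (history) convolution spans chunk
    # boundaries; the updated cache is returned for the next call.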
    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C

        cache = cache.to(x_per.device)
        y_left = torch.cat((cache, x_per), dim=2)
        cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
        y_left = self.conv_left(y_left)
        out = x_per + y_left

        if self.conv_right is not None:
            # shift ahead by rstride frames and zero-pad the tail so the
            # right-context (lookahead) convolution can cover future frames
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride :, :]
            y_right = self.conv_right(y_right)
            out += y_right
        # back to (B, T, D); return the updated cache for the next chunk
        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
        return output, cache


class BasicBlock(nn.Module):
    def __init__(
        self,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        stack_layer: int,
    ):
        super(BasicBlock, self).__init__()
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.stack_layer = stack_layer
        self.linear = LinearTransform(linear_dim, proj_dim)
        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)

    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
        x1 = self.linear(input)  # B T D
        cache_layer_name = "cache_layer_{}".format(self.stack_layer)
        if cache_layer_name not in cache:
            cache[cache_layer_name] = torch.zeros(
                x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
            )
        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4
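
# Each BasicBlock lazily creates its own cache entry, keyed "cache_layer_<i>",
# with shape (B, proj_dim, (lorder - 1) * lstride, 1).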


class BasicBlock_export(nn.Module):
    def __init__(
        self,
        model,
    ):
        super(BasicBlock_export, self).__init__()
        self.linear = model.linear
        self.fsmn_block = model.fsmn_block
        self.affine = model.affine
        self.relu = model.relu

    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
        # Reconstructed export variant of BasicBlock.forward: the cache is
        # passed and returned explicitly (instead of living in a Python dict)
        # so that FSMNExport.forward below can collect it.
        x = self.linear(input)  # B T D
        x, out_cache = self.fsmn_block(x, in_cache)
        x = self.affine(x)
        x = self.relu(x)
        return x, out_cache
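

# FsmnStack is referenced by FSMN below but not present in this excerpt; a
# minimal reconstruction: an nn.Sequential whose forward threads the shared
# cache dict through each BasicBlock in order.
class FsmnStack(nn.Sequential):
    def __init__(self, *args):
        super(FsmnStack, self).__init__(*args)

    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
        x = input
        for module in self._modules.values():
            x = module(x, cache)
        return x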


"""
FSMN net for keyword spotting
input_dim:    input dimension
linear_dim:   fsmn input dimension
proj_dim:     fsmn projection dimension
lorder:       fsmn left order
rorder:       fsmn right order
num_syn:      output dimension
fsmn_layers:  no. of sequential fsmn layers
"""


@tables.register("encoder_classes", "FSMN")
class FSMN(nn.Module):
    def __init__(
        self,
        input_dim: int,
        input_affine_dim: int,
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim

        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
        self.fsmn = FsmnStack(
            *[
                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
                for i in range(fsmn_layers)
            ]
        )
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

        # export to ONNX or TorchScript: swap each BasicBlock for its
        # cache-explicit export variant
        if "EXPORTING_MODEL" in os.environ and os.environ["EXPORTING_MODEL"] == "TRUE":
            for i, d in enumerate(self.fsmn):
                if isinstance(d, BasicBlock):
                    self.fsmn[i] = BasicBlock_export(d)

    def fuse_modules(self):
        pass

    def forward(
        self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            input (torch.Tensor): input features (B, T, D)
            cache (Dict[str, torch.Tensor]): per-layer streaming caches; pass {}
                for the first chunk (entries are created and updated in place)
        """
        # body reconstructed from the module definitions above
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)
        x4 = self.fsmn(x3, cache)
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7 = self.softmax(x6)
        return x7, cache

    def export_forward(
        self,
        input: torch.Tensor,
        *args,
    ):
        # reconstructed sketch: same computation as FSMNExport.forward below,
        # assuming the FSMN blocks were swapped to BasicBlock_export
        x = self.relu(self.in_linear2(self.in_linear1(input)))
        out_caches = []
        for layer, in_cache in zip(self.fsmn, args):
            x, out_cache = layer(x, in_cache)
            out_caches.append(out_cache)
        x = self.softmax(self.out_linear2(self.out_linear1(x)))
        return x, out_caches


@tables.register("encoder_classes", "FSMNExport")
class FSMNExport(nn.Module):
    def __init__(
        self,
        model,
        **kwargs,
    ):
        super().__init__()

        # self.input_dim = input_dim
        # self.input_affine_dim = input_affine_dim
        # self.fsmn_layers = fsmn_layers
        # self.linear_dim = linear_dim
        # self.proj_dim = proj_dim
        # self.output_affine_dim = output_affine_dim
        # self.output_dim = output_dim
        #
        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        # self.relu = RectifiedLinear(linear_dim, linear_dim)
        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
        #                         range(fsmn_layers)])
        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)
        self.in_linear1 = model.in_linear1
        self.in_linear2 = model.in_linear2
        self.relu = model.relu
        # self.fsmn = model.fsmn
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.fsmn[i] = BasicBlock_export(d)

    def fuse_modules(self):
        pass

    def forward(
        self,
        input: torch.Tensor,
        *args,
    ):
        """
        Args:
            input (torch.Tensor): input features (B, T, D)
            *args: one cache tensor per FSMN layer
        """
        # body reconstructed: thread one explicit cache per exported block
        x = self.relu(self.in_linear2(self.in_linear1(input)))
        out_caches = []
        for layer, in_cache in zip(self.fsmn, args):
            x, out_cache = layer(x, in_cache)
            out_caches.append(out_cache)
        x = self.softmax(self.out_linear2(self.out_linear1(x)))
        return x, out_caches
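

# Export usage sketch (illustrative only; file name and shapes are assumptions):
#   model = FSMNExport(fsmn)
#   caches = [torch.zeros(1, 128, 9, 1) for _ in range(4)]  # (B, proj_dim, (lorder - 1) * lstride, 1)
#   torch.onnx.export(model, (torch.zeros(1, 200, 400), *caches), "fsmn.onnx")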


"""
one deep fsmn layer
dimproj:    projection dimension, input and output dimension of memory blocks
dimlinear:  dimension of mapping layer
lorder:     left order
rorder:     right order
lstride:    left stride
rstride:    right stride
"""


@tables.register("encoder_classes", "DFSMN")
class DFSMN(nn.Module):

    def __init__(self, dimproj, dimlinear, lorder, rorder, lstride, rstride):
        super(DFSMN, self).__init__()

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        # expand/shrink pair around the memory block (reconstructed from the
        # docstring above and the surviving `shrink` line)
        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)

        self.conv_left = nn.Conv2d(
            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
        )

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
            )
        else:
            self.conv_right = None

    def forward(self, input):
        # reconstructed head: expand -> ReLU -> shrink, then pad the causal
        # left context before the memory convolutions
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)

        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)  # B D T C

        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride :, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)

        # project back to (B, T, D) and add the residual skip connection
        out1 = out.permute(0, 3, 2, 1)
        output = input + out1.squeeze(1)

        return output


"""
build stacked dfsmn layers
"""


def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
    repeats = [
        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
    ]

    return nn.Sequential(*repeats)
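

# Example (illustrative): a 6-layer stack over 64-dim projected features; since
# it is a plain nn.Sequential, a (B, T, 64) tensor flows straight through:
#   dfsmn = buildDFSMNRepeats(128, 64, 20, 1, 6)
#   out = dfsmn(torch.zeros(4, 100, 64))  # -> (4, 100, 64)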


if __name__ == "__main__":
    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)

    num_params = sum(p.numel() for p in fsmn.parameters())
    print("the number of model params: {}".format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x, {})  # batch-size * time * dim; {} starts a fresh streaming cache
    print("input shape: {}".format(x.shape))
    print("output shape: {}".format(y.shape))

    print(fsmn.to_kaldi_net())