from typing import Tuple, Dict

import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.fsmn_kws.encoder import (
    toKaldiMatrix,
    LinearTransform,
    AffineTransform,
    RectifiedLinear,
    FSMNBlock,
    FsmnStack,
    BasicBlock,
)
from funasr.register import tables

'''
FSMN net for keyword spotting, multi-task variant (two parallel output heads).

input_dim:         input dimension
input_affine_dim:  dimension of the affine transform applied to the input
linear_dim:        fsmn input dimension
proj_dim:          fsmn projection dimension
lorder:            fsmn left order
rorder:            fsmn right order
lstride/rstride:   left/right stride of the fsmn memory
output_affine_dim: affine dimension in front of each output layer
output_dim:        output dimension of head 1
output_dim2:       output dimension of head 2
fsmn_layers:       no. of sequential fsmn layers
'''


@tables.register("encoder_classes", "FSMNMT")
class FSMNMT(nn.Module):
    """Multi-task FSMN encoder with two parallel output heads.

    A shared trunk (two input affines -> ReLU -> FSMN stack) feeds two
    independent affine chains producing ``output_dim`` and ``output_dim2``
    sized outputs respectively.
    """

    def __init__(
        self,
        input_dim: int,
        input_affine_dim: int,
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
        output_dim2: int,
        use_softmax: bool = True,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim
        self.output_dim2 = output_dim2

        # Shared trunk.
        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
        self.fsmn = FsmnStack(
            *[
                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
                for i in range(fsmn_layers)
            ]
        )

        # Two parallel output heads on the trunk output.
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)

        self.use_softmax = use_softmax
        if self.use_softmax:
            self.softmax = nn.Softmax(dim=-1)

    def output_size(self) -> int:
        """Output dimension of the first head."""
        return self.output_dim

    def output_size2(self) -> int:
        """Output dimension of the second head."""
        return self.output_dim2

    def forward(
        self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the shared trunk and both output heads.

        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            cache: when cache is not None, the forward is in streaming.
                The type of cache is a dict, egs,
                {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to
                self.lorder. It is {} for the 1st frame

        Returns:
            Tuple of the two head outputs (the original annotation claimed a
            ``Dict`` second element, but two tensors are returned);
            softmax-normalized over the last dim when ``use_softmax``.
        """
        x = self.in_linear1(input)
        x = self.in_linear2(x)
        x = self.relu(x)
        x = self.fsmn(x, cache)  # cache will update automatically in self.fsmn
        out1 = self.out_linear2(self.out_linear1(x))
        out2 = self.out_linear2_2(self.out_linear1_2(x))
        if self.use_softmax:
            return self.softmax(out1), self.softmax(out2)
        return out1, out2


@tables.register("encoder_classes", "FSMNMTConvert")
class FSMNMTConvert(nn.Module):
    """FSMNMT twin used for Kaldi nnet1 text-format conversion.

    Holds the same layers/layout as ``FSMNMT`` and (de)serializes them
    to/from the Kaldi nnet1 text format, one output head at a time.
    """

    def __init__(
        self,
        input_dim: int,
        input_affine_dim: int,
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
        output_dim2: int,
        use_softmax: bool = True,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim
        self.output_dim2 = output_dim2

        # Must mirror FSMNMT.__init__ exactly so state dicts line up.
        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
        self.fsmn = FsmnStack(
            *[
                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
                for i in range(fsmn_layers)
            ]
        )
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)

        self.use_softmax = use_softmax
        if self.use_softmax:
            self.softmax = nn.Softmax(dim=-1)

    def output_size(self) -> int:
        """Output dimension of the first head."""
        return self.output_dim

    def output_size2(self) -> int:
        """Output dimension of the second head."""
        return self.output_dim2

    def to_kaldi_net(self):
        """Serialize the trunk plus head 1 to Kaldi nnet1 text format.

        NOTE(review): the markup tokens (<Nnet>, <Softmax>, </Nnet>) were
        absent from the original text (bare '\\n' / ' %d %d' lines), which
        made the header assertions in ``to_pytorch_net`` unsatisfiable and
        its ``softmax_split[2]`` access an IndexError; the standard Kaldi
        nnet1 tags are restored here — confirm against upstream funasr.
        """
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.in_linear1.to_kaldi_net()
        re_str += self.in_linear2.to_kaldi_net()
        re_str += self.relu.to_kaldi_net()
        for fsmn in self.fsmn:
            re_str += fsmn.to_kaldi_net()
        re_str += self.out_linear1.to_kaldi_net()
        re_str += self.out_linear2.to_kaldi_net()
        re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
        re_str += '</Nnet>\n'
        return re_str

    def to_kaldi_net2(self):
        """Serialize the trunk plus head 2 to Kaldi nnet1 text format.

        Same restored-tag caveat as ``to_kaldi_net``.
        """
        re_str = ''
        re_str += '<Nnet>\n'
        re_str += self.in_linear1.to_kaldi_net()
        re_str += self.in_linear2.to_kaldi_net()
        re_str += self.relu.to_kaldi_net()
        for fsmn in self.fsmn:
            re_str += fsmn.to_kaldi_net()
        re_str += self.out_linear1_2.to_kaldi_net()
        re_str += self.out_linear2_2.to_kaldi_net()
        re_str += '<Softmax> %d %d\n' % (self.output_dim2, self.output_dim2)
        re_str += '</Nnet>\n'
        return re_str

    def to_pytorch_net(self, kaldi_file):
        """Load trunk plus head-1 parameters from a Kaldi nnet1 text file.

        Fix: the original opened ``kaldi_file`` twice — once via ``with``
        and again by rebinding ``fread = open(...)`` inside it, then closed
        the second handle manually. A single managed handle suffices.
        The expected header/footer tokens mirror ``to_kaldi_net``.
        """
        with open(kaldi_file, 'r', encoding='utf8') as fread:
            nnet_start_line = fread.readline()
            assert nnet_start_line.strip() == '<Nnet>'
            self.in_linear1.to_pytorch_net(fread)
            self.in_linear2.to_pytorch_net(fread)
            self.relu.to_pytorch_net(fread)
            for fsmn in self.fsmn:
                fsmn.to_pytorch_net(fread)
            self.out_linear1.to_pytorch_net(fread)
            self.out_linear2.to_pytorch_net(fread)
            softmax_line = fread.readline()
            softmax_split = softmax_line.strip().split()
            assert softmax_split[0].strip() == '<Softmax>'
            assert int(softmax_split[1]) == self.output_dim
            assert int(softmax_split[2]) == self.output_dim
            nnet_end_line = fread.readline()
            assert nnet_end_line.strip() == '</Nnet>'