from typing import Tuple, Dict
import copy
import os

import numpy as np
import torch
import torch.nn as nn

        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4


class BasicBlock_export(nn.Module):
    def __init__(self,
                 model,
                 ):
        super(BasicBlock_export, self).__init__()
        self.linear = model.linear
        self.fsmn_block = model.fsmn_block
        self.affine = model.affine
        self.relu = model.relu

    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
        x = self.linear(input)  # B T D
        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
        # if cache_layer_name not in in_cache:
        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
        x, out_cache = self.fsmn_block(x, in_cache)
        x = self.affine(x)
        x = self.relu(x)
        return x, out_cache
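
# BasicBlock_export is an export-friendly wrapper around a trained BasicBlock: it reuses the
# trained sub-modules but takes the layer cache as an explicit tensor argument and returns the
# updated cache, so the graph can be traced for ONNX/TorchScript without dict-typed state.
# A minimal usage sketch; the zero-cache shape follows the commented-out initialization above
# and is an assumption here, and trained_block, chunk, B, D, lorder, lstride are hypothetical:
#
#     block = BasicBlock_export(trained_block)
#     cache = torch.zeros(B, D, (lorder - 1) * lstride, 1)   # zero cache for the first chunk
#     y, cache = block(chunk, cache)                          # feed chunks sequentially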


class FsmnStack(nn.Sequential):

        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

        # export to onnx or torchscript
        if "EXPORTING_MODEL" in os.environ and os.environ["EXPORTING_MODEL"] == "TRUE":
            for i, d in enumerate(self.fsmn):
                if isinstance(d, BasicBlock):
                    self.fsmn[i] = BasicBlock_export(d)

    def fuse_modules(self):
        pass
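
    # Sketch (based on the export switch in the constructor above): set the EXPORTING_MODEL
    # environment variable to 'TRUE' before constructing the model so every BasicBlock in
    # self.fsmn is replaced by a BasicBlock_export and the traced graph uses tensor-only caches.
    # The call site below is hypothetical and the constructor arguments are omitted:
    #
    #     os.environ['EXPORTING_MODEL'] = 'TRUE'
    #     model = build_model(...)   # hypothetical constructor for the class that owns this __init__
    #     # then trace/export via model.export_forward (e.g. torch.jit.trace or torch.onnx.export)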

        return x7

    def export_forward(
        self,
        input: torch.Tensor,
        *args,
    ):
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            *args: per-layer caches, passed positionally (one tensor per FSMN layer). When the
                caches are provided, the forward pass runs in streaming mode. In the dict form
                used by the non-export forward, the state looks like
                {'cache_layer_1': torch.Tensor(B, T1, D)}, where T1 equals self.lorder, and it
                is {} for the first frame.
        """

        x = self.in_linear1(input)
        x = self.in_linear2(x)
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
        x = self.out_linear1(x)
        x = self.out_linear2(x)
        x = self.softmax(x)

        return x, out_caches

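# A minimal streaming sketch for export_forward; the names (model, audio_chunks, fsmn_layers,
# B, D, lorder, lstride) are assumptions, and the zero-cache shape follows the commented-out
# initialization in BasicBlock_export above. Each layer's cache is passed positionally and the
# updated list is fed back in, which keeps the interface ONNX/TorchScript friendly:
#
#     caches = [torch.zeros(B, D, (lorder - 1) * lstride, 1) for _ in range(fsmn_layers)]
#     for chunk in audio_chunks:                        # chunk: (B, T, D) features
#         probs, caches = model.export_forward(chunk, *caches)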

'''
one deep fsmn layer