| | |
| | | |
| | | class FSMN(nn.Module): |
| | | def __init__( |
| | | self, |
| | | model, |
| | | self, model, |
| | | ): |
| | | super(FSMN, self).__init__() |
| | | |
| | |
| | | self.out_linear1 = model.out_linear1 |
| | | self.out_linear2 = model.out_linear2 |
| | | self.softmax = model.softmax |
| | | |
| | | for i, d in enumerate(self.model.fsmn): |
| | | self.fsmn = model.fsmn |
| | | for i, d in enumerate(model.fsmn): |
| | | if isinstance(d, BasicBlock): |
| | | self.model.fsmn[i] = BasicBlock_export(d) |
| | | self.fsmn[i] = BasicBlock_export(d) |
| | | |
| | | def fuse_modules(self): |
| | | pass |
| | |
| | | x = self.relu(x) |
| | | # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn |
| | | out_caches = list() |
| | | for i, d in enumerate(self.model.fsmn): |
| | | for i, d in enumerate(self.fsmn): |
| | | in_cache = args[i] |
| | | x, out_cache = d(x, in_cache) |
| | | out_caches.append(out_cache) |
| | |
| | | x = self.out_linear2(x) |
| | | x = self.softmax(x) |
| | | |
| | | return x, *out_caches |
| | | return x, out_caches |
| | | |
| | | |
| | | ''' |