游雁
2023-03-28 8a788ad0d922c7d1b7c597a610b131f40c93e2b5
funasr/export/models/encoder/fsmn_encoder.py
@@ -149,8 +149,7 @@
class FSMN(nn.Module):
    def __init__(
            self,
        model,
            self, model,
    ):
        super(FSMN, self).__init__()
        
@@ -177,10 +176,10 @@
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        for i, d in enumerate(self.model.fsmn):
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.model.fsmn[i] = BasicBlock_export(d)
                self.fsmn[i] = BasicBlock_export(d)
    def fuse_modules(self):
        pass
@@ -202,7 +201,7 @@
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.model.fsmn):
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
@@ -210,7 +209,7 @@
        x = self.out_linear2(x)
        x = self.softmax(x)
        return x, *out_caches
        return x, out_caches
'''