游雁
2023-03-28 8a788ad0d922c7d1b7c597a610b131f40c93e2b5
export
3个文件已修改
18 ■■■■ 行已修改（3 个文件合计）
funasr/export/export_model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_vad.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/fsmn_encoder.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py
@@ -193,6 +193,7 @@
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, 'cpu'
            )
            self.export_config["feats_dim"] = 400
        self._export(model, tag_name)
            
funasr/export/models/e2e_vad.py
@@ -11,7 +11,7 @@
class E2EVadModel(nn.Module):
    def __init__(self, model,
                max_seq_len=512,
                feats_dim=560,
                feats_dim=400,
                model_name='model',
                **kwargs,):
        super(E2EVadModel, self).__init__()
@@ -31,7 +31,7 @@
                       in_cache3: torch.Tensor,
                       ):
        scores, cache0, cache1, cache2, cache3 = self.encoder(feats,
        scores, (cache0, cache1, cache2, cache3) = self.encoder(feats,
                                                              in_cache0,
                                                              in_cache1,
                                                              in_cache2,
funasr/export/models/encoder/fsmn_encoder.py
@@ -149,8 +149,7 @@
class FSMN(nn.Module):
    def __init__(
            self,
        model,
            self, model,
    ):
        super(FSMN, self).__init__()
        
@@ -177,10 +176,10 @@
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        for i, d in enumerate(self.model.fsmn):
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.model.fsmn[i] = BasicBlock_export(d)
                self.fsmn[i] = BasicBlock_export(d)
    def fuse_modules(self):
        pass
@@ -202,7 +201,7 @@
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.model.fsmn):
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
@@ -210,7 +209,7 @@
        x = self.out_linear2(x)
        x = self.softmax(x)
        return x, *out_caches
        return x, out_caches
'''