From d46a542fae26009eee16204a81903862cb4dba73 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Mon, 10 Apr 2023 16:02:41 +0800
Subject: [PATCH] Merge branch 'main' into dev_aky
---
funasr/export/models/e2e_vad.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 60 insertions(+), 0 deletions(-)
diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py
new file mode 100644
index 0000000..d3e8f30
--- /dev/null
+++ b/funasr/export/models/e2e_vad.py
@@ -0,0 +1,60 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+
+import torch
+from torch import nn
+import math
+
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
+
+class E2EVadModel(nn.Module):
+    """Export wrapper for a VAD model whose encoder is an ``FSMN``.
+
+    The encoder is re-wrapped as ``FSMN_export`` and ``forward`` delegates to it.
+    """
+
+    def __init__(self, model,
+                 max_seq_len=512,
+                 feats_dim=400,
+                 model_name='model',
+                 **kwargs,):
+        """Store the export settings and wrap ``model.encoder``.
+
+        Raises:
+            TypeError: if ``model.encoder`` is not an ``FSMN`` instance.
+        """
+        super(E2EVadModel, self).__init__()
+        self.feats_dim = feats_dim
+        self.max_seq_len = max_seq_len
+        self.model_name = model_name
+        if not isinstance(model.encoder, FSMN):
+            # A bare string cannot be raised in Python 3; use a real exception.
+            raise TypeError("unsupported encoder")
+        self.encoder = FSMN_export(model.encoder)
+
+    def forward(self, feats: torch.Tensor, *args, ):
+        """Pass ``feats`` and any extra ``args`` to the wrapped encoder."""
+        scores, out_caches = self.encoder(feats, *args)
+        return scores, out_caches
+
+    def get_dummy_inputs(self, frame=30):
+        """Return random speech ``(1, frame, feats_dim)`` plus four caches."""
+        speech = torch.randn(1, frame, self.feats_dim)
+        # Drawn after ``speech``, one by one, so the RNG order is unchanged.
+        in_caches = [torch.randn(1, 128, 19, 1) for _ in range(4)]
+        return (speech, *in_caches)
+
+    def get_input_names(self):
+        """Return one name per element of the ``get_dummy_inputs`` tuple."""
+        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+
+    def get_output_names(self):
+        """Return the output names: ``logits`` and four output caches."""
+        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+
+    def get_dynamic_axes(self):
+        """Return the axes left dynamic: axis 1 of ``speech``."""
+        return {
+            'speech': {1: 'feats_length'},
+        }
--
Gitblit v1.9.1