1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
| from enum import Enum
| from typing import List, Tuple, Dict, Any
|
| import torch
| from torch import nn
| import math
|
| from funasr.models.encoder.fsmn_encoder import FSMN
| from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
|
class E2EVadModel(nn.Module):
    """Export wrapper around the FSMN encoder of an end-to-end VAD model.

    The wrapped ``model.encoder`` is converted into its export counterpart
    (``FSMN_export``). The helper methods supply the dummy inputs, input/output
    names and dynamic axes that a graph exporter consumes (presumably
    ``torch.onnx.export`` -- confirm against the calling export script).
    """

    def __init__(self, model,
                 max_seq_len=512,
                 feats_dim=560,
                 model_name='model',
                 **kwargs,):
        """Wrap ``model.encoder`` for export.

        Args:
            model: Source model. Only its ``encoder`` attribute is used, and it
                must be an ``FSMN`` instance.
            max_seq_len: Stored as ``self.max_seq_len``; not read elsewhere in
                this class.
            feats_dim: Last dimension of the dummy ``speech`` tensor built by
                ``get_dummy_inputs``.
            model_name: Stored as ``self.model_name``; not read elsewhere in
                this class.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            TypeError: If ``model.encoder`` is not an ``FSMN`` encoder.
        """
        super().__init__()
        self.feats_dim = feats_dim
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        if not isinstance(model.encoder, FSMN):
            # Only FSMN encoders have an export counterpart here. A real
            # exception class is required: raising a bare string is invalid
            # in Python 3.
            raise TypeError(
                f"unsupported encoder: {type(model.encoder).__name__}")
        self.encoder = FSMN_export(model.encoder)

    def forward(self, feats: torch.Tensor,
                in_cache0: torch.Tensor,
                in_cache1: torch.Tensor,
                in_cache2: torch.Tensor,
                in_cache3: torch.Tensor,
                ) -> Tuple[torch.Tensor, ...]:
        """Run the export encoder on one chunk of features.

        Args:
            feats: Input features. The dummy input built by
                ``get_dummy_inputs`` is ``(1, frame, feats_dim)``.
            in_cache0, in_cache1, in_cache2, in_cache3: Incoming cache tensors,
                passed straight through to the encoder (the dummy inputs use
                shape ``(1, 128, 19, 1)``).

        Returns:
            ``(scores, cache0, cache1, cache2, cache3)`` as produced by the
            encoder. ``scores`` is presumably ``(B, T, D)`` -- TODO confirm.
        """
        # Unpacking enforces that the encoder yields exactly five outputs and
        # normalises the result to a tuple.
        scores, cache0, cache1, cache2, cache3 = self.encoder(
            feats, in_cache0, in_cache1, in_cache2, in_cache3)
        return scores, cache0, cache1, cache2, cache3

    def get_dummy_inputs(self, frame=30):
        """Build random example inputs matching ``forward``'s signature.

        Args:
            frame: Number of time frames in the dummy ``speech`` tensor.

        Returns:
            ``(speech, in_cache0, in_cache1, in_cache2, in_cache3)``, where
            ``speech`` has shape ``(1, frame, self.feats_dim)`` and every cache
            has shape ``(1, 128, 19, 1)``.
        """
        speech = torch.randn(1, frame, self.feats_dim)
        # NOTE(review): the cache shape is hard-coded; presumably it must match
        # the FSMN encoder configuration -- confirm before changing the model.
        caches = tuple(torch.randn(1, 128, 19, 1) for _ in range(4))
        return (speech, *caches)

    def get_input_names(self) -> List[str]:
        """Return the graph input names, in ``forward`` argument order."""
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']

    def get_output_names(self) -> List[str]:
        """Return the graph output names, in ``forward`` return order."""
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping for export.

        Only dimension 1 of ``speech`` (named ``'feats_length'``) is declared
        dynamic; no other input or output gets a dynamic axis here.
        """
        return {
            'speech': {
                1: 'feats_length'
            },
        }
|
|