aky15
2023-04-10 d46a542fae26009eee16204a81903862cb4dba73
funasr/export/models/e2e_vad.py
New file
@@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
class E2EVadModel(nn.Module):
    """Export wrapper for a VAD model whose encoder is an ``FSMN``.

    The wrapped model's ``encoder`` is converted to its export counterpart
    (``FSMN_export``). The ``get_*`` helpers describe the dummy inputs,
    input/output names and dynamic axes of the graph — presumably consumed
    by the ONNX exporter (TODO confirm against the export driver).
    """

    # Shape and count of the cache tensors produced by ``get_dummy_inputs``.
    _CACHE_SHAPE = (1, 128, 19, 1)
    _NUM_CACHES = 4

    def __init__(self, model,
                 max_seq_len=512,
                 feats_dim=400,
                 model_name='model',
                 **kwargs,):
        """Wrap ``model`` for export.

        Args:
            model: Source model; only its ``encoder`` attribute is read here.
            max_seq_len: Stored on the instance; not used inside this class.
            feats_dim: Size of the last axis of the dummy ``speech`` input.
            model_name: Stored on the instance; not used inside this class.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            TypeError: If ``model.encoder`` is not an ``FSMN`` instance.
        """
        super().__init__()
        self.feats_dim = feats_dim
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        if not isinstance(model.encoder, FSMN):
            # The previous `raise "unsupported encoder"` raised a bare string,
            # which Python 3 turns into an unrelated TypeError ("exceptions
            # must derive from BaseException"). Raise a real TypeError so the
            # exception class seen by callers is unchanged but the message is
            # meaningful.
            raise TypeError(
                f"unsupported encoder: {type(model.encoder).__name__}"
            )
        self.encoder = FSMN_export(model.encoder)

    def forward(self, feats: torch.Tensor, *args, ):
        """Run the export encoder on ``feats`` plus any extra positional inputs.

        The extra positional arguments are forwarded untouched to the encoder
        (``get_dummy_inputs`` supplies four cache tensors there). Returns the
        encoder's two results as ``(scores, out_caches)``.
        """
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches

    def get_dummy_inputs(self, frame=30) -> Tuple[torch.Tensor, ...]:
        """Build random example inputs matching ``get_input_names``.

        Args:
            frame: Length of the second axis of the ``speech`` tensor.

        Returns:
            ``(speech, in_cache0, in_cache1, in_cache2, in_cache3)`` where
            ``speech`` has shape ``(1, frame, feats_dim)`` and every cache
            has shape ``(1, 128, 19, 1)``.
        """
        speech = torch.randn(1, frame, self.feats_dim)
        # Caches are drawn after `speech`, in index order.
        in_caches = tuple(
            torch.randn(*self._CACHE_SHAPE) for _ in range(self._NUM_CACHES)
        )
        return (speech, *in_caches)

    # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
    #     import numpy as np
    #     fbank = np.loadtxt(txt_file)
    #     fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
    #     speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
    #     speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
    #     return (speech, speech_lengths)

    def get_input_names(self) -> List[str]:
        """Return the graph input names, in the order of ``get_dummy_inputs``."""
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']

    def get_output_names(self) -> List[str]:
        """Return the graph output names: the logits, then four output caches."""
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']

    def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic-axis mapping: only axis 1 of ``speech`` varies."""
        return {
            'speech': {
                1: 'feats_length'
            },
        }