From 4ba1011b42e041ee1d71448eefd7ef2e7bd61bb6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 31 三月 2023 15:31:26 +0800
Subject: [PATCH] export

---
 funasr/export/models/vad_realtime_transformer.py |   45 ++++++++++++++++++++++++++-------------------
 1 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/funasr/export/models/vad_realtime_transformer.py b/funasr/export/models/vad_realtime_transformer.py
index 44583d8..693b9c8 100644
--- a/funasr/export/models/vad_realtime_transformer.py
+++ b/funasr/export/models/vad_realtime_transformer.py
@@ -1,17 +1,12 @@
-from typing import Any
-from typing import List
 from typing import Tuple
 
 import torch
 import torch.nn as nn
 
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.sanm_encoder import SANMVadEncoder
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
 from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
 
-class VadRealtimeTransformer(AbsPunctuation):
+class VadRealtimeTransformer(nn.Module):
 
     def __init__(
         self,
@@ -21,7 +16,9 @@
         **kwargs,
     ):
         super().__init__()
-
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
 
         self.embed = model.embed
         if isinstance(model.encoder, SANMVadEncoder):
@@ -30,11 +27,15 @@
             assert False, "Only support samn encode."
         # self.encoder = model.encoder
         self.decoder = model.decoder
+        self.model_name = model_name
 
 
 
-    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
-                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def forward(self, input: torch.Tensor,
+                text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor,
+                sub_masks: torch.Tensor,
+                ) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
 
         Args:
@@ -44,7 +45,7 @@
         """
         x = self.embed(input)
         # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
+        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
         y = self.decoder(h)
         return y
 
@@ -53,12 +54,15 @@
 
     def get_dummy_inputs(self):
         length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
-        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
-        return (text_indexes, text_lengths)
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
+        text_lengths = torch.tensor([length], dtype=torch.int32)
+        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+        sub_masks = torch.ones(length, length, dtype=torch.float32)
+        sub_masks = torch.tril(sub_masks).type(torch.float32)
+        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
 
     def get_input_names(self):
-        return ['input', 'text_lengths']
+        return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
 
     def get_output_names(self):
         return ['logits']
@@ -66,14 +70,17 @@
     def get_dynamic_axes(self):
         return {
             'input': {
-                0: 'batch_size',
                 1: 'feats_length'
             },
-            'text_lengths': {
-                0: 'batch_size',
+            'vad_mask': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
+            'sub_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
             },
             'logits': {
-                0: 'batch_size',
                 1: 'logits_length'
             },
         }

--
Gitblit v1.9.1