From d4aaa84ad16c2c862ffcb5d73bf7852c8ee90d24 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 21 三月 2024 14:17:22 +0800
Subject: [PATCH] fix func FunASRWfstDecoderInit

---
 funasr/models/ct_transformer/model.py |   55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 54 insertions(+), 1 deletions(-)

diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 45f5746..9f680fd 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -18,7 +18,6 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 
-
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -365,3 +364,57 @@
         results.append(result_i)
         return results, meta_data
 
+    def export(
+        self,
+        **kwargs,
+    ):
+
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+        self.forward = self.export_forward
+        
+        return self
+
+    def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        h, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y
+
+    def export_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+        return (text_indexes, text_lengths)
+
+    def export_input_names(self):
+        return ['inputs', 'text_lengths']
+
+    def export_output_names(self):
+        return ['logits']
+
+    def export_dynamic_axes(self):
+        return {
+            'inputs': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'text_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+    def export_name(self):
+        return "model.onnx"
\ No newline at end of file

--
Gitblit v1.9.1