From 2ae59b6ce06305724e2eaf30b9f9e93447a7832e Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 22 七月 2024 16:58:27 +0800
Subject: [PATCH] ONNX and torchscript export for sensevoice

---
 funasr/models/sense_voice/export_meta.py |   58 +++++++++++++++++++---------------------------------------
 1 files changed, 19 insertions(+), 39 deletions(-)

diff --git a/funasr/models/sense_voice/export_meta.py b/funasr/models/sense_voice/export_meta.py
index fe09ee1..449388e 100644
--- a/funasr/models/sense_voice/export_meta.py
+++ b/funasr/models/sense_voice/export_meta.py
@@ -5,30 +5,19 @@
 
 import types
 import torch
-import torch.nn as nn
-from funasr.register import tables
+from funasr.utils.torch_function import sequence_mask
 
 
 def export_rebuild_model(model, **kwargs):
     model.device = kwargs.get("device")
-    is_onnx = kwargs.get("type", "onnx") == "onnx"
-    # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
-    # model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
-    from funasr.utils.torch_function import sequence_mask
-
     model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
-
     model.forward = types.MethodType(export_forward, model)
     model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
     model.export_input_names = types.MethodType(export_input_names, model)
     model.export_output_names = types.MethodType(export_output_names, model)
     model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
     model.export_name = types.MethodType(export_name, model)
-
-    model.export_name = "model"
     return model
-
 
 def export_forward(
     self,
@@ -38,32 +27,28 @@
     textnorm: torch.Tensor,
     **kwargs,
 ):
-    speech = speech.to(device=kwargs["device"])
-    speech_lengths = speech_lengths.to(device=kwargs["device"])
-
-    language_query = self.embed(language).to(speech.device)
-
-    textnorm_query = self.embed(textnorm).to(speech.device)
+    # speech = speech.to(device="cuda")
+    # speech_lengths = speech_lengths.to(device="cuda")
+    language_query = self.embed(language.to(speech.device)).unsqueeze(1)
+    textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
+    
     speech = torch.cat((textnorm_query, speech), dim=1)
-    speech_lengths += 1
-
+    
     event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
         speech.size(0), 1, 1
     )
     input_query = torch.cat((language_query, event_emo_query), dim=1)
     speech = torch.cat((input_query, speech), dim=1)
-    speech_lengths += 3
-
-    # Encoder
-    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+    
+    speech_lengths_new = speech_lengths + 4
+    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths_new)
+    
     if isinstance(encoder_out, tuple):
         encoder_out = encoder_out[0]
 
-    # c. Passed the encoder result and the beam search
-    ctc_logits = self.ctc.log_softmax(encoder_out)
-
+    ctc_logits = self.ctc.ctc_lo(encoder_out)
+    
     return ctc_logits, encoder_out_lens
-
 
 def export_dummy_inputs(self):
     speech = torch.randn(2, 30, 560)
@@ -72,26 +57,21 @@
     textnorm = torch.tensor([15, 15], dtype=torch.int32)
     return (speech, speech_lengths, language, textnorm)
 
-
 def export_input_names(self):
     return ["speech", "speech_lengths", "language", "textnorm"]
-
 
 def export_output_names(self):
     return ["ctc_logits", "encoder_out_lens"]
 
-
 def export_dynamic_axes(self):
     return {
         "speech": {0: "batch_size", 1: "feats_length"},
-        "speech_lengths": {
-            0: "batch_size",
-        },
-        "logits": {0: "batch_size", 1: "logits_length"},
+        "speech_lengths": {0: "batch_size"},
+        "language": {0: "batch_size"},
+        "textnorm": {0: "batch_size"},
+        "ctc_logits": {0: "batch_size", 1: "logits_length"},
+        "encoder_out_lens":  {0: "batch_size"},
     }
 
-
-def export_name(
-    self,
-):
+def export_name(self):
     return "model.onnx"

--
Gitblit v1.9.1