From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/ct_transformer_streaming/export_meta.py |   47 +++++++++++++++++++++++------------------------
 1 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/funasr/models/ct_transformer_streaming/export_meta.py b/funasr/models/ct_transformer_streaming/export_meta.py
index e4745d6..ba0283e 100644
--- a/funasr/models/ct_transformer_streaming/export_meta.py
+++ b/funasr/models/ct_transformer_streaming/export_meta.py
@@ -9,25 +9,28 @@
 
 
 def export_rebuild_model(model, **kwargs):
-    
+
     is_onnx = kwargs.get("type", "onnx") == "onnx"
     encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
     model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-    
+
     model.forward = types.MethodType(export_forward, model)
     model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
     model.export_input_names = types.MethodType(export_input_names, model)
     model.export_output_names = types.MethodType(export_output_names, model)
     model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
     model.export_name = types.MethodType(export_name, model)
-        
+
     return model
 
-def export_forward(self, inputs: torch.Tensor,
-            text_lengths: torch.Tensor,
-            vad_indexes: torch.Tensor,
-            sub_masks: torch.Tensor,
-            ):
+
+def export_forward(
+    self,
+    inputs: torch.Tensor,
+    text_lengths: torch.Tensor,
+    vad_indexes: torch.Tensor,
+    sub_masks: torch.Tensor,
+):
     """Compute loss value from buffer sequences.
 
     Args:
@@ -41,6 +44,7 @@
     y = self.decoder(h)
     return y
 
+
 def export_dummy_inputs(self):
     length = 120
     text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
@@ -50,28 +54,23 @@
     sub_masks = torch.tril(sub_masks).type(torch.float32)
     return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
 
+
 def export_input_names(self):
-    return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+    return ["inputs", "text_lengths", "vad_masks", "sub_masks"]
+
 
 def export_output_names(self):
-    return ['logits']
+    return ["logits"]
+
 
 def export_dynamic_axes(self):
     return {
-        'inputs': {
-            1: 'feats_length'
-        },
-        'vad_masks': {
-            2: 'feats_length1',
-            3: 'feats_length2'
-        },
-        'sub_masks': {
-            2: 'feats_length1',
-            3: 'feats_length2'
-        },
-        'logits': {
-            1: 'logits_length'
-        },
+        "inputs": {1: "feats_length"},
+        "vad_masks": {2: "feats_length1", 3: "feats_length2"},
+        "sub_masks": {2: "feats_length1", 3: "feats_length2"},
+        "logits": {1: "logits_length"},
     }
+
+
 def export_name(self):
     return "model.onnx"

--
Gitblit v1.9.1