From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/ct_transformer_streaming/export_meta.py | 47 +++++++++++++++++++++++------------------------
1 files changed, 23 insertions(+), 24 deletions(-)
diff --git a/funasr/models/ct_transformer_streaming/export_meta.py b/funasr/models/ct_transformer_streaming/export_meta.py
index e4745d6..ba0283e 100644
--- a/funasr/models/ct_transformer_streaming/export_meta.py
+++ b/funasr/models/ct_transformer_streaming/export_meta.py
@@ -9,25 +9,28 @@
def export_rebuild_model(model, **kwargs):
-
+
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
+
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
-
+
return model
-def export_forward(self, inputs: torch.Tensor,
- text_lengths: torch.Tensor,
- vad_indexes: torch.Tensor,
- sub_masks: torch.Tensor,
- ):
+
+def export_forward(
+ self,
+ inputs: torch.Tensor,
+ text_lengths: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ sub_masks: torch.Tensor,
+):
"""Compute loss value from buffer sequences.
Args:
@@ -41,6 +44,7 @@
y = self.decoder(h)
return y
+
def export_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
@@ -50,28 +54,23 @@
sub_masks = torch.tril(sub_masks).type(torch.float32)
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
def export_input_names(self):
- return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+ return ["inputs", "text_lengths", "vad_masks", "sub_masks"]
+
def export_output_names(self):
- return ['logits']
+ return ["logits"]
+
def export_dynamic_axes(self):
return {
- 'inputs': {
- 1: 'feats_length'
- },
- 'vad_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'sub_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'logits': {
- 1: 'logits_length'
- },
+ "inputs": {1: "feats_length"},
+ "vad_masks": {2: "feats_length1", 3: "feats_length2"},
+ "sub_masks": {2: "feats_length1", 3: "feats_length2"},
+ "logits": {1: "logits_length"},
}
+
+
def export_name(self):
return "model.onnx"
--
Gitblit v1.9.1