From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/contextual_paraformer/export_meta.py | 90 ++++++++++++++++++++++++++++-----------------
1 files changed, 56 insertions(+), 34 deletions(-)
diff --git a/funasr/models/contextual_paraformer/export_meta.py b/funasr/models/contextual_paraformer/export_meta.py
index 6e1067a..9d3a63b 100644
--- a/funasr/models/contextual_paraformer/export_meta.py
+++ b/funasr/models/contextual_paraformer/export_meta.py
@@ -11,56 +11,82 @@
class ContextualEmbedderExport2(ContextualEmbedderExport):
- def __init__(self,
- model,
- **kwargs):
+ def __init__(self, model, **kwargs):
super().__init__(model)
self.embedding = model.bias_embed
model.bias_encoder.batch_first = False
self.bias_encoder = model.bias_encoder
+
+ def export_dummy_inputs(self):
+ hotword = torch.tensor(
+ [
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=torch.int32,
+ )
+ # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+ return (hotword)
def export_rebuild_model(model, **kwargs):
is_onnx = kwargs.get("type", "onnx") == "onnx"
-
+
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
model.encoder = encoder_class(model.encoder, onnx=is_onnx)
predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
model.predictor = predictor_class(model.predictor, onnx=is_onnx)
-
+
# little difference with bias encoder with seaco paraformer
embedder_class = ContextualEmbedderExport2
embedder_model = embedder_class(model, onnx=is_onnx)
-
+
if kwargs["decoder"] == "ParaformerSANMDecoder":
kwargs["decoder"] = "ParaformerSANMDecoderOnline"
decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
model.decoder = decoder_class(model.decoder, onnx=is_onnx)
from funasr.utils.torch_function import sequence_mask
+
model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
model.feats_dim = 560
import copy
+
backbone_model = copy.copy(model)
# backbone
backbone_model.forward = types.MethodType(export_backbone_forward, backbone_model)
- backbone_model.export_dummy_inputs = types.MethodType(export_backbone_dummy_inputs, backbone_model)
- backbone_model.export_input_names = types.MethodType(export_backbone_input_names, backbone_model)
- backbone_model.export_output_names = types.MethodType(export_backbone_output_names, backbone_model)
- backbone_model.export_dynamic_axes = types.MethodType(export_backbone_dynamic_axes, backbone_model)
- backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+ backbone_model.export_dummy_inputs = types.MethodType(
+ export_backbone_dummy_inputs, backbone_model
+ )
+ backbone_model.export_input_names = types.MethodType(
+ export_backbone_input_names, backbone_model
+ )
+ backbone_model.export_output_names = types.MethodType(
+ export_backbone_output_names, backbone_model
+ )
+ backbone_model.export_dynamic_axes = types.MethodType(
+ export_backbone_dynamic_axes, backbone_model
+ )
+ embedder_model.export_name = "model_eb"
+ backbone_model.export_name = "model"
+
return backbone_model, embedder_model
+
def export_backbone_forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- bias_embed: torch.Tensor,
- ):
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ bias_embed: torch.Tensor,
+):
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_len = self.encoder(**batch)
@@ -73,36 +99,32 @@
return decoder_out, pre_token_length
+
def export_backbone_dummy_inputs(self):
speech = torch.randn(2, 30, self.feats_dim)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
bias_embed = torch.randn(2, 1, 512)
return (speech, speech_lengths, bias_embed)
+
def export_backbone_input_names(self):
- return ['speech', 'speech_lengths', 'bias_embed']
+ return ["speech", "speech_lengths", "bias_embed"]
+
def export_backbone_output_names(self):
- return ['logits', 'token_num']
+ return ["logits", "token_num"]
+
def export_backbone_dynamic_axes(self):
return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
+ "speech": {0: "batch_size", 1: "feats_length"},
+ "speech_lengths": {
+ 0: "batch_size",
},
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'bias_embed': {
- 0: 'batch_size',
- 1: 'num_hotwords'
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
+ "bias_embed": {0: "batch_size", 1: "num_hotwords"},
+ "logits": {0: "batch_size", 1: "logits_length"},
}
-
+
+
def export_backbone_name(self):
- return 'model.onnx'
\ No newline at end of file
+ return "model.onnx"
--
Gitblit v1.9.1