From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/ct_transformer/export_meta.py | 34 +++++++++++++++++-----------------
1 files changed, 17 insertions(+), 17 deletions(-)
diff --git a/funasr/models/ct_transformer/export_meta.py b/funasr/models/ct_transformer/export_meta.py
index 691f211..c222008 100644
--- a/funasr/models/ct_transformer/export_meta.py
+++ b/funasr/models/ct_transformer/export_meta.py
@@ -9,18 +9,18 @@
def export_rebuild_model(model, **kwargs):
-
+
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
+
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
-
+
return model
@@ -37,31 +37,31 @@
y = self.decoder(h)
return y
+
def export_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
- text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+ text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
+
def export_input_names(self):
- return ['inputs', 'text_lengths']
+ return ["inputs", "text_lengths"]
+
def export_output_names(self):
- return ['logits']
+ return ["logits"]
+
def export_dynamic_axes(self):
return {
- 'inputs': {
- 0: 'batch_size',
- 1: 'feats_length'
+ "inputs": {0: "batch_size", 1: "feats_length"},
+ "text_lengths": {
+ 0: "batch_size",
},
- 'text_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
+ "logits": {0: "batch_size", 1: "logits_length"},
}
+
+
def export_name(self):
- return "model.onnx"
\ No newline at end of file
+ return "model.onnx"
--
Gitblit v1.9.1