From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/ct_transformer/export_meta.py |   34 +++++++++++++++++-----------------
 1 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/funasr/models/ct_transformer/export_meta.py b/funasr/models/ct_transformer/export_meta.py
index 691f211..c222008 100644
--- a/funasr/models/ct_transformer/export_meta.py
+++ b/funasr/models/ct_transformer/export_meta.py
@@ -9,18 +9,18 @@
 
 
 def export_rebuild_model(model, **kwargs):
-    
+
     is_onnx = kwargs.get("type", "onnx") == "onnx"
     encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
     model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-    
+
     model.forward = types.MethodType(export_forward, model)
     model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
     model.export_input_names = types.MethodType(export_input_names, model)
     model.export_output_names = types.MethodType(export_output_names, model)
     model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
     model.export_name = types.MethodType(export_name, model)
-        
+
     return model
 
 
@@ -37,31 +37,31 @@
     y = self.decoder(h)
     return y
 
+
 def export_dummy_inputs(self):
     length = 120
     text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
-    text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+    text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
     return (text_indexes, text_lengths)
 
+
 def export_input_names(self):
-    return ['inputs', 'text_lengths']
+    return ["inputs", "text_lengths"]
+
 
 def export_output_names(self):
-    return ['logits']
+    return ["logits"]
+
 
 def export_dynamic_axes(self):
     return {
-        'inputs': {
-            0: 'batch_size',
-            1: 'feats_length'
+        "inputs": {0: "batch_size", 1: "feats_length"},
+        "text_lengths": {
+            0: "batch_size",
         },
-        'text_lengths': {
-            0: 'batch_size',
-        },
-        'logits': {
-            0: 'batch_size',
-            1: 'logits_length'
-        },
+        "logits": {0: "batch_size", 1: "logits_length"},
     }
+
+
 def export_name(self):
-    return "model.onnx"
\ No newline at end of file
+    return "model.onnx"

--
Gitblit v1.9.1