From c4ac64fd5d24bb3fc8ccc441d36a07c83c8b9015 Mon Sep 17 00:00:00 2001
From: Yu Cao <monstercy@hotmail.com>
Date: Wed, 01 Oct 2025 14:46:21 +0800
Subject: [PATCH] fix "can not find model issue when running libtorch runtime" (#2504)

---
 funasr/models/paraformer/export_meta.py |   80 ++++++++++++++++++++-------------------
 1 file changed, 41 insertions(+), 39 deletions(-)

diff --git a/funasr/models/paraformer/export_meta.py b/funasr/models/paraformer/export_meta.py
index 4d491e9..8e086a2 100644
--- a/funasr/models/paraformer/export_meta.py
+++ b/funasr/models/paraformer/export_meta.py
@@ -9,36 +9,37 @@
 
 
 def export_rebuild_model(model, **kwargs):
-        model.device = kwargs.get("device")
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-        
-        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
-        model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+    model.device = kwargs.get("device")
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
 
+    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
 
-        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
-        model.decoder = decoder_class(model.decoder, onnx=is_onnx)
-        
-        from funasr.utils.torch_function import sequence_mask
-        model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
-        
-        model.forward = types.MethodType(export_forward, model)
-        model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
-        model.export_input_names = types.MethodType(export_input_names, model)
-        model.export_output_names = types.MethodType(export_output_names, model)
-        model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
-        model.export_name = types.MethodType(export_name, model)
-        
-        return model
+    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+    model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+    from funasr.utils.torch_function import sequence_mask
+
+    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    model.export_name = "model"
+    return model
 
 
 def export_forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ):
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
     # a. To device
     batch = {"speech": speech, "speech_lengths": speech_lengths}
     # batch = to_device(batch, device=self.device)
@@ -54,6 +55,7 @@
 
     return decoder_out, pre_token_length
 
+
 def export_dummy_inputs(self):
     speech = torch.randn(2, 30, 560)
     speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
@@ -61,25 +63,25 @@
 
 
 def export_input_names(self):
-    return ['speech', 'speech_lengths']
+    return ["speech", "speech_lengths"]
+
 
 def export_output_names(self):
-    return ['logits', 'token_num']
+    return ["logits", "token_num"]
+
 
 def export_dynamic_axes(self):
     return {
-        'speech': {
-            0: 'batch_size',
-            1: 'feats_length'
+        "speech": {0: "batch_size", 1: "feats_length"},
+        "speech_lengths": {
+            0: "batch_size",
         },
-        'speech_lengths': {
-            0: 'batch_size',
-        },
-        'logits': {
-            0: 'batch_size',
-            1: 'logits_length'
-        },
+        "logits": {0: "batch_size", 1: "logits_length"},
+        "token_num": {0: "batch_size"}
     }
 
-def export_name(self, ):
-    return "model.onnx"
\ No newline at end of file
+
+def export_name(
+    self,
+):
+    return "model.onnx"

--
Gitblit v1.9.1