From e299cfecaf979833d9c4d7c70e44cb92ea066afe Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 09 五月 2024 20:02:37 +0800
Subject: [PATCH] total_time/accum_grad
---
funasr/models/bicif_paraformer/export_meta.py | 39 +++++++++++++++++----------------------
1 files changed, 17 insertions(+), 22 deletions(-)
diff --git a/funasr/models/bicif_paraformer/export_meta.py b/funasr/models/bicif_paraformer/export_meta.py
index 7ae800e..e9d0a25 100644
--- a/funasr/models/bicif_paraformer/export_meta.py
+++ b/funasr/models/bicif_paraformer/export_meta.py
@@ -8,6 +8,7 @@
from funasr.register import tables
+
def export_rebuild_model(model, **kwargs):
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
@@ -21,7 +22,7 @@
from funasr.utils.torch_function import sequence_mask
- model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
+ model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
@@ -31,6 +32,7 @@
model.export_name = types.MethodType(export_name, model)
return model
+
def export_forward(
self,
@@ -53,39 +55,32 @@
return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
def export_dummy_inputs(self):
speech = torch.randn(2, 30, 560)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
return (speech, speech_lengths)
+
def export_input_names(self):
- return ['speech', 'speech_lengths']
+ return ["speech", "speech_lengths"]
+
def export_output_names(self):
- return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+ return ["logits", "token_num", "us_alphas", "us_cif_peak"]
+
def export_dynamic_axes(self):
return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
+ "speech": {0: "batch_size", 1: "feats_length"},
+ "speech_lengths": {
+ 0: "batch_size",
},
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- 'us_alphas': {
- 0: 'batch_size',
- 1: 'alphas_length'
- },
- 'us_cif_peak': {
- 0: 'batch_size',
- 1: 'alphas_length'
- },
+ "logits": {0: "batch_size", 1: "logits_length"},
+ "us_alphas": {0: "batch_size", 1: "alphas_length"},
+ "us_cif_peak": {0: "batch_size", 1: "alphas_length"},
}
+
def export_name(self):
- return "model.onnx"
\ No newline at end of file
+ return "model.onnx"
--
Gitblit v1.9.1