From 6ca0b838d48106030984eacf204e8f1f2f05985b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 16:07:49 +0800
Subject: [PATCH] decoding
---
funasr/models/paraformer_streaming/export_meta.py | 98 +++++++++++++++++++++++--------------------------
1 files changed, 46 insertions(+), 52 deletions(-)
diff --git a/funasr/models/paraformer_streaming/export_meta.py b/funasr/models/paraformer_streaming/export_meta.py
index 0193dc8..9740ecc 100644
--- a/funasr/models/paraformer_streaming/export_meta.py
+++ b/funasr/models/paraformer_streaming/export_meta.py
@@ -9,29 +9,29 @@
def export_rebuild_model(model, **kwargs):
- model.device = kwargs.get("device")
- is_onnx = kwargs.get("type", "onnx") == "onnx"
- encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
- model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
- predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
- model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+ model.device = kwargs.get("device")
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+ model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+ predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+ model.predictor = predictor_class(model.predictor, onnx=is_onnx)
- decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
- model.decoder = decoder_class(model.decoder, onnx=is_onnx)
-
- from funasr.utils.torch_function import sequence_mask
- model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
-
- model.forward = types.MethodType(export_forward, model)
- model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
- model.export_input_names = types.MethodType(export_input_names, model)
- model.export_output_names = types.MethodType(export_output_names, model)
- model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
- model.export_name = types.MethodType(export_name, model)
-
- return model
+ decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+ model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+ from funasr.utils.torch_function import sequence_mask
+
+ model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
+
+ model.forward = types.MethodType(export_forward, model)
+ model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+ model.export_input_names = types.MethodType(export_input_names, model)
+ model.export_output_names = types.MethodType(export_output_names, model)
+ model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+ model.export_name = types.MethodType(export_name, model)
+
+ return model
def export_rebuild_model(model, **kwargs):
@@ -39,24 +39,25 @@
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
+
predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
model.predictor = predictor_class(model.predictor, onnx=is_onnx)
-
+
if kwargs["decoder"] == "ParaformerSANMDecoder":
kwargs["decoder"] = "ParaformerSANMDecoderOnline"
decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
model.decoder = decoder_class(model.decoder, onnx=is_onnx)
-
+
from funasr.utils.torch_function import sequence_mask
-
+
model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)
-
+
import copy
import types
+
encoder_model = copy.copy(model)
decoder_model = copy.copy(model)
-
+
# encoder
encoder_model.forward = types.MethodType(export_encoder_forward, encoder_model)
encoder_model.export_dummy_inputs = types.MethodType(export_encoder_dummy_inputs, encoder_model)
@@ -64,7 +65,7 @@
encoder_model.export_output_names = types.MethodType(export_encoder_output_names, encoder_model)
encoder_model.export_dynamic_axes = types.MethodType(export_encoder_dynamic_axes, encoder_model)
encoder_model.export_name = types.MethodType(export_encoder_name, encoder_model)
-
+
# decoder
decoder_model.forward = types.MethodType(export_decoder_forward, decoder_model)
decoder_model.export_dummy_inputs = types.MethodType(export_decoder_dummy_inputs, decoder_model)
@@ -72,7 +73,7 @@
decoder_model.export_output_names = types.MethodType(export_decoder_output_names, decoder_model)
decoder_model.export_dynamic_axes = types.MethodType(export_decoder_dynamic_axes, decoder_model)
decoder_model.export_name = types.MethodType(export_decoder_name, decoder_model)
-
+
return encoder_model, decoder_model
@@ -84,11 +85,11 @@
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
# batch = to_device(batch, device=self.device)
-
+
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
alphas, _ = self.predictor.forward_cnn(enc, mask)
-
+
return enc, enc_len, alphas
@@ -99,33 +100,24 @@
def export_encoder_input_names(self):
- return ['speech', 'speech_lengths']
+ return ["speech", "speech_lengths"]
def export_encoder_output_names(self):
- return ['enc', 'enc_len', 'alphas']
+ return ["enc", "enc_len", "alphas"]
def export_encoder_dynamic_axes(self):
return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
+ "speech": {0: "batch_size", 1: "feats_length"},
+ "speech_lengths": {
+ 0: "batch_size",
},
- 'speech_lengths': {
- 0: 'batch_size',
+ "enc": {0: "batch_size", 1: "feats_length"},
+ "enc_len": {
+ 0: "batch_size",
},
- 'enc': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'enc_len': {
- 0: 'batch_size',
- },
- 'alphas': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
+ "alphas": {0: "batch_size", 1: "feats_length"},
}
@@ -141,9 +133,11 @@
acoustic_embeds_len: torch.Tensor,
*args,
):
- decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
+ decoder_out, out_caches = self.decoder(
+ enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args
+ )
sample_ids = decoder_out.argmax(dim=-1)
-
+
return decoder_out, sample_ids, out_caches
@@ -165,4 +159,4 @@
def export_decoder_name(self):
- return "decoder.onnx"
\ No newline at end of file
+ return "decoder.onnx"
--
Gitblit v1.9.1