From e04489ce4c0fd0095d0c79ef8f504f425e0435a8 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Wed, 13 Mar 2024 16:34:42 +0800
Subject: [PATCH] contextual&seaco ONNX export (#1481)
---
funasr/models/paraformer/model.py | 85 -
funasr/models/sond/encoder/self_attention_encoder.py | 150 ---
funasr/models/sond/encoder/ci_scorers.py | 6
funasr/models/seaco_paraformer/export_meta.py | 181 +++
funasr/models/sond/encoder/fsmn_encoder.py | 137 --
funasr/models/contextual_paraformer/export_meta.py | 108 ++
funasr/models/seaco_paraformer/model.py | 11
runtime/python/onnxruntime/funasr_onnx/__init__.py | 2
funasr/models/bicif_paraformer/model.py | 89 -
funasr/models/bicif_paraformer/export_meta.py | 91 +
funasr/models/sanm/attention.py | 7
funasr/models/scama/encoder.py | 157 ---
funasr/models/sond/encoder/resnet34_encoder.py | 422 --------
funasr/models/fsmn_vad_streaming/model.py | 48
runtime/python/onnxruntime/demo_contextual_paraformer.py | 6
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 36
funasr/models/contextual_paraformer/model.py | 19
funasr/models/fsmn_vad_streaming/export_meta.py | 59 +
funasr/models/scama/decoder.py | 380 -------
funasr/models/transformer/positionwise_feed_forward.py | 13
funasr/models/contextual_paraformer/decoder.py | 563 ++---------
funasr/models/sond/encoder/conv_encoder.py | 100 --
funasr/models/paraformer/decoder.py | 76 +
funasr/models/paraformer/export_meta.py | 85 +
runtime/python/onnxruntime/demo_paraformer_offline.py | 6
runtime/python/onnxruntime/demo_seaco_paraformer.py | 12
funasr/models/sond/pooling/statistic_pooling.py | 3
runtime/python/onnxruntime/demo_paraformer_online.py | 2
funasr/models/transformer/utils/subsampling.py | 43
29 files changed, 807 insertions(+), 2090 deletions(-)
diff --git a/funasr/models/bicif_paraformer/export_meta.py b/funasr/models/bicif_paraformer/export_meta.py
new file mode 100644
index 0000000..7ae800e
--- /dev/null
+++ b/funasr/models/bicif_paraformer/export_meta.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
+import types
+
+from funasr.register import tables
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+
+    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+    model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+    from funasr.utils.torch_function import sequence_mask
+
+    model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    return model
+
+def export_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
+    # pack inputs into a batch (device handling is left to the caller)
+    batch = {"speech": speech, "speech_lengths": speech_lengths}
+
+    enc, enc_len = self.encoder(**batch)
+    mask = self.make_pad_mask(enc_len)[:, None, :]
+    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+    pre_token_length = pre_token_length.round().type(torch.int32)
+
+    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+    decoder_out = torch.log_softmax(decoder_out, dim=-1)
+
+    # get predicted timestamps
+    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+    return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+def export_dummy_inputs(self):
+    speech = torch.randn(2, 30, 560)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    return (speech, speech_lengths)
+
+def export_input_names(self):
+    return ['speech', 'speech_lengths']
+
+def export_output_names(self):
+    return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+def export_dynamic_axes(self):
+    return {
+        'speech': {
+            0: 'batch_size',
+            1: 'feats_length'
+        },
+        'speech_lengths': {
+            0: 'batch_size',
+        },
+        'logits': {
+            0: 'batch_size',
+            1: 'logits_length'
+        },
+        'us_alphas': {
+            0: 'batch_size',
+            1: 'alphas_length'
+        },
+        'us_cif_peak': {
+            0: 'batch_size',
+            1: 'alphas_length'
+        },
+    }
+
+def export_name(self):
+    return "model.onnx"
\ No newline at end of file
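
Note: the export_* hooks above define the full ONNX I/O contract, so the exported file can be driven directly with onnxruntime. A minimal sketch (the model path is a placeholder, and LFR+CMVN feature extraction is assumed to have already produced the 560-dim inputs):

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx")  # path per export_name() above
    speech = np.random.randn(1, 30, 560).astype(np.float32)  # (batch, frames, feat)
    speech_lengths = np.array([30], dtype=np.int32)
    logits, token_num, us_alphas, us_cif_peak = sess.run(
        None, {"speech": speech, "speech_lengths": speech_lengths}
    )
    # greedy decode: keep the first token_num argmax ids per utterance
    token_ids = logits[0].argmax(axis=-1)[: int(token_num[0])]
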
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 9849c8c..6f37dd4 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -343,86 +343,9 @@
return results, meta_data
-    def export(
-        self,
-        max_seq_len=512,
-        **kwargs,
-    ):
-        self.device = kwargs.get("device")
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-
-        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
-        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
-
-        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
-        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
-
-        from funasr.utils.torch_function import sequence_mask
-
-        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-
-        self.forward = self.export_forward
-
-        return self
-
-    def export_forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        batch = to_device(batch, device=self.device)
-
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.round().type(torch.int32)
-
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-
-        # get predicted timestamps
-        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
-
-        return decoder_out, pre_token_length, us_alphas, us_cif_peak
-
-    def export_dummy_inputs(self):
-        speech = torch.randn(2, 30, 560)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-    def export_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def export_output_names(self):
-        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
-
-    def export_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-            'us_alphas': {
-                0: 'batch_size',
-                1: 'alphas_length'
-            },
-            'us_cif_peak': {
-                0: 'batch_size',
-                1: 'alphas_length'
-            },
-        }
-
-    def export_name(self, ):
-        return "model.onnx"
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
index c872547..1116f84 100644
--- a/funasr/models/contextual_paraformer/decoder.py
+++ b/funasr/models/contextual_paraformer/decoder.py
@@ -305,473 +305,128 @@
x = self.output_layer(x)
return x, olens
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
+@tables.register("decoder_classes", "ContextualParaformerDecoderExport")
+class ContextualParaformerDecoderExport(torch.nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True,
+                 **kwargs,):
+        super().__init__()
+        from funasr.utils.torch_function import sequence_mask
+        self.model = model
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+        from funasr.models.paraformer.decoder import DecoderLayerSANMExport
+        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANMExport
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANMExport(d)
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                    d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANMExport(d)
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
+        for i, d in enumerate(self.model.decoders3):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
+            self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
+        # bias decoder
+        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.bias_decoder.src_attn)
+        self.bias_decoder = self.model.bias_decoder
+
+        # last decoder
+        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.last_decoder.src_attn)
+        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
+            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoderExport(self.model.last_decoder.self_attn)
+        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
+            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANMExport(self.model.last_decoder.feed_forward)
+        self.last_decoder = self.model.last_decoder
+        self.bias_output = self.model.bias_output
+        self.dropout = self.model.dropout
+
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        bias_embed: torch.Tensor,
+    ):
- ## clas decoder
- # src att
- "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.bias_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (1024,256),(1,256,1024)
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
- }
- return map_dict_local
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "last_decoder":
- layeridx = 15
- name_q = name.replace("last_decoder", "decoders.layeridx")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        _, _, x_self_attn, x_src_attn = self.last_decoder(
+            x, tgt_mask, memory, memory_mask
+        )
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
+        # contextual paraformer related
+        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
+        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
+        contextual_mask = self.make_pad_mask(contextual_length)
+        contextual_mask, _ = self.prepare_mask(contextual_mask)
+        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
+        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
+        if self.bias_output is not None:
+            x = torch.cat([x_src_attn, cx], dim=2)
+            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
+            x = x_self_attn + self.dropout(x)
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+        if self.model.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "bias_decoder":
- name_q = name
+        return x, ys_in_lens
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
-
- elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
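
Note: prepare_mask above converts a (B, T) padding mask into the two forms the exported layers expect, a multiplicative (B, T, 1) mask for zeroing padded frames and an additive (B, 1, 1, T) bias of -10000 on padded keys. A standalone sketch of the convention (shapes only, not the class itself):

    import torch

    mask = torch.tensor([[1.0, 1.0, 0.0]])                  # (B=1, T=3), 0 = padding
    mask_3d_btd = mask[:, :, None]                          # multiplied into features
    mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0  # added to attention scores
    scores = torch.zeros(1, 1, 2, 3)                        # (B, heads, L, T)
    attn = torch.softmax(scores + mask_4d_bhlt, dim=-1)     # padded keys get ~0 weight
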
diff --git a/funasr/models/contextual_paraformer/export_meta.py b/funasr/models/contextual_paraformer/export_meta.py
new file mode 100644
index 0000000..6e1067a
--- /dev/null
+++ b/funasr/models/contextual_paraformer/export_meta.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
+import types
+
+from funasr.register import tables
+from funasr.models.seaco_paraformer.export_meta import ContextualEmbedderExport
+
+
+class ContextualEmbedderExport2(ContextualEmbedderExport):
+    def __init__(self,
+                 model,
+                 **kwargs):
+        super().__init__(model)
+        self.embedding = model.bias_embed
+        model.bias_encoder.batch_first = False
+        self.bias_encoder = model.bias_encoder
+
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+
+    # the bias encoder here differs slightly from the SeaCo Paraformer one
+    embedder_class = ContextualEmbedderExport2
+    embedder_model = embedder_class(model, onnx=is_onnx)
+
+    if kwargs["decoder"] == "ParaformerSANMDecoder":
+        kwargs["decoder"] = "ParaformerSANMDecoderOnline"
+    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+    model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+    from funasr.utils.torch_function import sequence_mask
+    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
+    model.feats_dim = 560
+
+    import copy
+    backbone_model = copy.copy(model)
+
+    # backbone
+    backbone_model.forward = types.MethodType(export_backbone_forward, backbone_model)
+    backbone_model.export_dummy_inputs = types.MethodType(export_backbone_dummy_inputs, backbone_model)
+    backbone_model.export_input_names = types.MethodType(export_backbone_input_names, backbone_model)
+    backbone_model.export_output_names = types.MethodType(export_backbone_output_names, backbone_model)
+    backbone_model.export_dynamic_axes = types.MethodType(export_backbone_dynamic_axes, backbone_model)
+    backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+
+    return backbone_model, embedder_model
+
+def export_backbone_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+    bias_embed: torch.Tensor,
+):
+    batch = {"speech": speech, "speech_lengths": speech_lengths}
+
+    enc, enc_len = self.encoder(**batch)
+    mask = self.make_pad_mask(enc_len)[:, None, :]
+    pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(enc, mask)
+    pre_token_length = pre_token_length.floor().type(torch.int32)
+
+    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
+    decoder_out = torch.log_softmax(decoder_out, dim=-1)
+
+    return decoder_out, pre_token_length
+
+def export_backbone_dummy_inputs(self):
+    speech = torch.randn(2, 30, self.feats_dim)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    bias_embed = torch.randn(2, 1, 512)
+    return (speech, speech_lengths, bias_embed)
+
+def export_backbone_input_names(self):
+    return ['speech', 'speech_lengths', 'bias_embed']
+
+def export_backbone_output_names(self):
+    return ['logits', 'token_num']
+
+def export_backbone_dynamic_axes(self):
+    return {
+        'speech': {
+            0: 'batch_size',
+            1: 'feats_length'
+        },
+        'speech_lengths': {
+            0: 'batch_size',
+        },
+        'bias_embed': {
+            0: 'batch_size',
+            1: 'num_hotwords'
+        },
+        'logits': {
+            0: 'batch_size',
+            1: 'logits_length'
+        },
+    }
+
+def export_backbone_name(self):
+    return 'model.onnx'
\ No newline at end of file
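
Note: unlike the plain Paraformer export, export_rebuild_model returns two modules here, the backbone and the hotword embedder, which are exported separately. A hedged two-stage inference sketch once both are ONNX files (the embedder file name, its input name, and the 512-dim embedding size are assumptions taken from export_backbone_dummy_inputs, not fixed by this patch):

    import numpy as np
    import onnxruntime as ort

    eb_sess = ort.InferenceSession("model_eb.onnx")  # assumed embedder file name
    am_sess = ort.InferenceSession("model.onnx")     # backbone, per export_backbone_name()

    hotword_ids = np.array([[10, 25, 3]], dtype=np.int32)  # assumed padded hotword token ids
    bias_embed = eb_sess.run(None, {eb_sess.get_inputs()[0].name: hotword_ids})[0]
    bias_embed = bias_embed.reshape(1, -1, 512).astype(np.float32)  # (batch, num_hotwords, 512)

    speech = np.random.randn(1, 30, 560).astype(np.float32)
    speech_lengths = np.array([30], dtype=np.int32)
    logits, token_num = am_sess.run(
        None, {"speech": speech, "speech_lengths": speech_lengths, "bias_embed": bias_embed}
    )
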
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 18cab60..9968bf2 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -17,9 +17,6 @@
from distutils.version import LooseVersion
from funasr.register import tables
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
@@ -29,7 +26,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
+
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -80,7 +77,6 @@
if self.crit_attn_weight > 0:
self.attn_loss = torch.nn.L1Loss()
self.crit_attn_smooth = crit_attn_smooth
-
def forward(
self,
@@ -156,7 +152,6 @@
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
-
def _calc_att_clas_loss(
self,
encoder_out: torch.Tensor,
@@ -231,7 +226,6 @@
return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
-
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
ys_pad = ys_pad * tgt_mask[:, :, 0]
@@ -263,7 +257,6 @@
sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
input_mask_expand_dim, 0)
return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
clas_scale=1.0):
@@ -414,7 +407,6 @@
return results, meta_data
-
def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
def load_seg_dict(seg_dict_file):
seg_dict = {}
@@ -516,3 +508,12 @@
hotword_list = None
return hotword_list
+    def export(
+        self,
+        **kwargs,
+    ):
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
diff --git a/funasr/models/fsmn_vad_streaming/export_meta.py b/funasr/models/fsmn_vad_streaming/export_meta.py
new file mode 100644
index 0000000..7183026
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/export_meta.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    return model
+
+def export_forward(self, feats: torch.Tensor, *args, **kwargs):
+
+    scores, out_caches = self.encoder(feats, *args)
+
+    return scores, out_caches
+
+def export_dummy_inputs(self, data_in=None, frame=30):
+    if data_in is None:
+        speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
+    else:
+        speech = None  # TODO: build dummy inputs from data_in
+
+    cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
+    in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+    in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+    in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+    in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+
+    return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
+
+def export_input_names(self):
+    return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+
+def export_output_names(self):
+    return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+
+def export_dynamic_axes(self):
+    return {
+        'speech': {
+            1: 'feats_length'
+        },
+    }
+
+def export_name(self):
+    return "model.onnx"
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 602cf23..59536bb 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -644,49 +644,11 @@
return results, meta_data
def export(self, **kwargs):
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        self.forward = self.export_forward
-
-        return self
-
-    def export_forward(self, feats: torch.Tensor, *args, **kwargs):
-
-        scores, out_caches = self.encoder(feats, *args)
-
-        return scores, out_caches
-
-    def export_dummy_inputs(self, data_in=None, frame=30):
-        if data_in is None:
-            speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
-        else:
-            speech = None # Undo
-
-        cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
-        in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
-        in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
-        in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
-        in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
-
-        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
-
-    def export_input_names(self):
-        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
-
-    def export_output_names(self):
-        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
-
-    def export_dynamic_axes(self):
-        return {
-            'speech': {
-                1: 'feats_length'
-            },
-        }
-
-    def export_name(self, ):
-        return "model.onnx"
-
+
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
+
def DetectCommonFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 7c370ba..f08e97b 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -616,6 +616,22 @@
return x, tgt_mask, memory, memory_mask, cache
+
+    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn is not None:
+            tgt = self.norm2(tgt)
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + x
+
+        residual = x
+        x = self.norm3(x)
+        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+        return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
@@ -675,6 +691,8 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
+        return_hidden: bool = False,
+        return_both: bool = False,
):
tgt = ys_in_pad
@@ -698,11 +716,60 @@
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
-        x = self.after_norm(x)
-        x = self.output_layer(x)
+        hidden = self.after_norm(x)
+        # x = self.output_layer(x)
-        return x, ys_in_lens
+        if self.output_layer is not None and return_hidden is False:
+            x = self.output_layer(hidden)
+            return x, ys_in_lens
+        if return_both:
+            x = self.output_layer(hidden)
+            return x, hidden, ys_in_lens
+        return hidden, ys_in_lens
+    def forward_asf2(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        _, memory_mask = self.prepare_mask(memory_mask)
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    def forward_asf6(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        _, memory_mask = self.prepare_mask(memory_mask)
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
+        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    '''
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
@@ -751,7 +818,8 @@
for d in range(cache_num)
})
return ret
-
+    '''
+
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
def __init__(self, model,
diff --git a/funasr/models/paraformer/export_meta.py b/funasr/models/paraformer/export_meta.py
new file mode 100644
index 0000000..4d491e9
--- /dev/null
+++ b/funasr/models/paraformer/export_meta.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    model.device = kwargs.get("device")
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
+    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+
+
+    decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
+    model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+    from funasr.utils.torch_function import sequence_mask
+    model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    return model
+
+
+def export_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
+    # a. To device
+    batch = {"speech": speech, "speech_lengths": speech_lengths}
+    # batch = to_device(batch, device=self.device)
+
+    enc, enc_len = self.encoder(**batch)
+    mask = self.make_pad_mask(enc_len)[:, None, :]
+    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+    pre_token_length = pre_token_length.floor().type(torch.int32)
+
+    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+    decoder_out = torch.log_softmax(decoder_out, dim=-1)
+    # sample_ids = decoder_out.argmax(dim=-1)
+
+    return decoder_out, pre_token_length
+
+def export_dummy_inputs(self):
+    speech = torch.randn(2, 30, 560)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    return (speech, speech_lengths)
+
+
+def export_input_names(self):
+    return ['speech', 'speech_lengths']
+
+def export_output_names(self):
+    return ['logits', 'token_num']
+
+def export_dynamic_axes(self):
+    return {
+        'speech': {
+            0: 'batch_size',
+            1: 'feats_length'
+        },
+        'speech_lengths': {
+            0: 'batch_size',
+        },
+        'logits': {
+            0: 'batch_size',
+            1: 'logits_length'
+        },
+    }
+
+def export_name(self):
+    return "model.onnx"
\ No newline at end of file
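
Note: the rebuilt model carries its own dummy inputs, I/O names, and dynamic axes, so a generic exporter only has to forward them to torch.onnx.export. A minimal driver sketch (the opset choice and output directory are assumptions; FunASR's own export utility may differ):

    import torch

    def export_onnx(model, out_dir="."):
        torch.onnx.export(
            model,
            model.export_dummy_inputs(),
            f"{out_dir}/{model.export_name()}",
            input_names=model.export_input_names(),
            output_names=model.export_output_names(),
            dynamic_axes=model.export_dynamic_axes(),
            opset_version=14,          # assumption
            do_constant_folding=True,
        )
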
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 41a1bf7..316255d 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -13,15 +13,16 @@
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.train_utils.device_funcs import to_device
+
@tables.register("model_classes", "Paraformer")
class Paraformer(torch.nn.Module):
@@ -549,78 +550,10 @@
return results, meta_data
-    def export(
-        self,
-        max_seq_len=512,
-        **kwargs,
-    ):
-        self.device = kwargs.get("device")
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-
-        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
-        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
-
-        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
-        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
-
-        from funasr.utils.torch_function import sequence_mask
-
-
-        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        self.forward = self.export_forward
-
-        return self
-
-    def export_forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        batch = to_device(batch, device=self.device)
-
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.floor().type(torch.int32)
-
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
-
-        return decoder_out, pre_token_length
-
-    def export_dummy_inputs(self):
-        speech = torch.randn(2, 30, 560)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-
-    def export_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def export_output_names(self):
-        return ['logits', 'token_num']
-
-    def export_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-    def export_name(self, ):
-        return "model.onnx"
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 5f91268..1768bbd 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -697,10 +697,10 @@
self.attn = None
self.all_head_size = self.h * self.d_k
-    def forward(self, x, memory, memory_mask):
+    def forward(self, x, memory, memory_mask, ret_attn=False):
q, k, v = self.forward_qkv(x, memory)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, memory_mask)
+        return self.forward_attention(v, scores, memory_mask, ret_attn)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.h, self.d_k)
@@ -717,7 +717,7 @@
v = self.transpose_for_scores(v)
return q, k, v
-    def forward_attention(self, value, scores, mask):
+    def forward_attention(self, value, scores, mask, ret_attn):
scores = scores + mask
self.attn = torch.softmax(scores, dim=-1)
@@ -726,6 +726,7 @@
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
+        if ret_attn: return self.linear_out(context_layer), self.attn
return self.linear_out(context_layer) # (batch, time1, d_model)
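
Note: ret_attn threads the post-softmax attention map out of MultiHeadedAttentionCrossAtt, which is what get_attn_mat relies on. A self-contained usage sketch (the constructor signature (n_head, n_feat, dropout_rate) is an assumption; check attention.py):

    import torch
    from funasr.models.sanm.attention import MultiHeadedAttentionCrossAtt

    attn = MultiHeadedAttentionCrossAtt(4, 256, 0.0)       # assumed signature
    x, memory = torch.randn(2, 5, 256), torch.randn(2, 7, 256)
    mask = torch.zeros(2, 1, 1, 7)                         # additive mask, 0 = keep
    out, attn_mat = attn(x, memory, mask, ret_attn=True)   # attn_mat: (2, 4, 5, 7)
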
diff --git a/funasr/models/scama/decoder.py b/funasr/models/scama/decoder.py
index 9dcb9da..8257f59 100644
--- a/funasr/models/scama/decoder.py
+++ b/funasr/models/scama/decoder.py
@@ -474,383 +474,3 @@
return y, new_cache
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
- map_dict_local = {
-
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
-
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
- "{}/w_embs".format(embed_tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
-
- }
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
-
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed" or names[1] == "output_layer":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
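
Every converter removed from this file (and from the encoder files below) followed the same pattern: a key templated with a `layeridx` placeholder maps a torch parameter to a TF tensor name plus optional `squeeze`/`transpose` steps. The standalone sketch below replays one deleted entry on a stand-in array; the TF name prefix is illustrative, and the shapes are the ones noted in the removed map.

import numpy as np
import torch

# deleted map entry: a TF 1x1-conv kernel stored as (1, in_dim, out_dim)
# becomes a torch Linear weight of shape (out_dim, in_dim)
entry = {
    "name": "seq2seq/decoder/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel",
    "squeeze": 0,
    "transpose": (1, 0),
}
name_tf = entry["name"].replace("layeridx", "3")          # resolve layer index 3
data_tf = np.zeros((1, 256, 1024), dtype=np.float32)      # stand-in for var_dict_tf[name_tf]
if entry["squeeze"] is not None:
    data_tf = np.squeeze(data_tf, axis=entry["squeeze"])  # -> (256, 1024)
if entry["transpose"] is not None:
    data_tf = np.transpose(data_tf, entry["transpose"])   # -> (1024, 256)
weight = torch.from_numpy(data_tf).float()                # decoders.3.feed_forward.w_1.weight
print(name_tf, weight.shape)
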
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 3651e61..2c676b2 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -460,160 +460,3 @@
return xs_pad, ilens, None
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
diff --git a/funasr/models/seaco_paraformer/export_meta.py b/funasr/models/seaco_paraformer/export_meta.py
new file mode 100644
index 0000000..260b625
--- /dev/null
+++ b/funasr/models/seaco_paraformer/export_meta.py
@@ -0,0 +1,181 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import torch
+
+from funasr.register import tables
+
+
+class ContextualEmbedderExport(torch.nn.Module):
+ def __init__(self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ **kwargs,):
+ super().__init__()
+ self.embedding = model.decoder.embed # model.bias_embed
+ model.bias_encoder.batch_first = False
+ self.bias_encoder = model.bias_encoder
+
+ def forward(self, hotword):
+        hotword = self.embedding(hotword).transpose(0, 1)  # to (T, B, D); bias_encoder runs with batch_first=False
+ hw_embed, (_, _) = self.bias_encoder(hotword)
+ return hw_embed
+
+ def export_dummy_inputs(self):
+ hotword = torch.tensor([
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=torch.int32)
+ # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+ return (hotword)
+
+ def export_input_names(self):
+ return ['hotword']
+
+ def export_output_names(self):
+ return ['hw_embed']
+
+ def export_dynamic_axes(self):
+ return {
+ 'hotword': {
+ 0: 'num_hotwords',
+ },
+ 'hw_embed': {
+ 0: 'num_hotwords',
+ },
+ }
+
+ def export_name(self):
+ return 'model_eb.onnx'
+
+
+def export_rebuild_model(model, **kwargs):
+ model.device = kwargs.get("device")
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+ model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+ predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
+ model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+
+    # build the embedder export wrapper before the decoder is converted to its export class
+ embedder_class = ContextualEmbedderExport
+ embedder_model = embedder_class(model, onnx=is_onnx)
+
+ decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
+ model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+ seaco_decoder_class = tables.decoder_classes.get(kwargs["seaco_decoder"]+"Export")
+ model.seaco_decoder = seaco_decoder_class(model.seaco_decoder, onnx=is_onnx)
+
+ from funasr.utils.torch_function import sequence_mask
+ model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
+
+    model.feats_dim = 560  # acoustic feature dimension (560 = LFR-stacked fbank for Paraformer)
+    model.NOBIAS = 8377  # vocabulary id of the <no-bias> token used when merging logits
+
+ import copy
+ import types
+ backbone_model = copy.copy(model)
+
+ # backbone
+ backbone_model.forward = types.MethodType(export_backbone_forward, backbone_model)
+ backbone_model.export_dummy_inputs = types.MethodType(export_backbone_dummy_inputs, backbone_model)
+ backbone_model.export_input_names = types.MethodType(export_backbone_input_names, backbone_model)
+ backbone_model.export_output_names = types.MethodType(export_backbone_output_names, backbone_model)
+ backbone_model.export_dynamic_axes = types.MethodType(export_backbone_dynamic_axes, backbone_model)
+ backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+
+ return backbone_model, embedder_model
+
+
+def export_backbone_forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ bias_embed: torch.Tensor,
+ # lmbd: float,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+ pre_token_length = pre_token_length.floor().type(torch.int32)
+
+ decoder_out, decoder_hidden, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, return_hidden=True, return_both=True)
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ # seaco forward
+ B, N, D = bias_embed.shape
+ _contextual_length = torch.ones(B) * N
+
+    # ASF: attention-score filtering of hotword candidates
+ hotword_scores = self.seaco_decoder.forward_asf6(bias_embed, _contextual_length, decoder_hidden, pre_token_length)
+ hotword_scores = hotword_scores[0].sum(0).sum(0)
+ # _ = self.decoder2(bias_embed, _contextual_length, decoder_hidden, pre_token_length)
+ # hotword_scores = self.decoder2.model.decoders[-1].attn_mat[0][0].sum(0).sum(0)
+    # keep the 51 best-scoring candidates (assumed: top-50 hotwords plus the <no-bias> slot)
+    dec_filter = torch.sort(hotword_scores, descending=True)[1][:51]
+    contextual_info = bias_embed[:, dec_filter]
+ num_hot_word = contextual_info.shape[1]
+ _contextual_length = torch.Tensor([num_hot_word]).int().repeat(B).to(enc.device)
+
+    # second pass: attend over the filtered hotwords from both the CIF and decoder hidden states
+ cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, pre_token_length)
+ dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, pre_token_length)
+ merged = cif_attended + dec_attended
+ dha_output = self.hotword_output_layer(merged)
+ dha_pred = torch.log_softmax(dha_output, dim=-1)
+ # merging logits
+ dha_ids = dha_pred.max(-1)[-1]
+ dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
+ decoder_out = decoder_out * dha_mask + dha_pred * (1-dha_mask)
+ return decoder_out, pre_token_length, alphas
+
+def export_backbone_dummy_inputs(self):
+ speech = torch.randn(2, 30, self.feats_dim)
+ speech_lengths = torch.tensor([15, 30], dtype=torch.int32)
+ bias_embed = torch.randn(2, 1, 512)
+ return (speech, speech_lengths, bias_embed)
+
+def export_backbone_input_names(self):
+ return ['speech', 'speech_lengths', 'bias_embed']
+
+def export_backbone_output_names(self):
+ return ['logits', 'token_num', 'alphas']
+
+def export_backbone_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'bias_embed': {
+ 0: 'batch_size',
+ 1: 'num_hotwords'
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ 'pre_acoustic_embeds': {
+ 1: 'feats_length1'
+ }
+ }
+
+def export_backbone_name(self):
+ return 'model.onnx'
+
\ No newline at end of file
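
The two exported graphs are meant to run as a pair: `model_eb.onnx` turns padded hotword token ids into embeddings, and the backbone consumes them as `bias_embed`. A hedged sketch with raw onnxruntime follows; the file paths, the random stand-in features, and the pooling of the time-major LSTM output down to `(batch, num_hotwords, 512)` are assumptions, not part of this patch.

import numpy as np
import onnxruntime as ort

eb = ort.InferenceSession("model_eb.onnx")  # hotword embedder
bb = ort.InferenceSession("model.onnx")     # ASR backbone

# padded hotword ids: (num_hotwords, max_hotword_len)
hotword = np.array([[10, 11, 12, 0, 0, 0, 0, 0, 0, 0]], dtype=np.int32)
hw_embed = eb.run(["hw_embed"], {"hotword": hotword})[0]

# assumption: keep the last LSTM step per hotword, then add a batch axis
bias_embed = hw_embed[-1][np.newaxis, :, :].astype(np.float32)

speech = np.random.randn(1, 100, 560).astype(np.float32)  # stand-in features
speech_lengths = np.array([100], dtype=np.int32)
logits, token_num, alphas = bb.run(
    ["logits", "token_num", "alphas"],
    {"speech": speech, "speech_lengths": speech_lengths, "bias_embed": bias_embed},
)
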
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 92fc989..27ff5d1 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -430,7 +430,6 @@
return results, meta_data
-
def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
def load_seg_dict(seg_dict_file):
seg_dict = {}
@@ -532,3 +531,13 @@
hotword_list = None
return hotword_list
+ def export(
+ self,
+ **kwargs,
+ ):
+ if 'max_seq_len' not in kwargs:
+ kwargs['max_seq_len'] = 512
+ from .export_meta import export_rebuild_model
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
+
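
Like the other Paraformer variants, `export()` defers to `export_rebuild_model`, but here it hands back two modules rather than one, each carrying its own `export_*` metadata. A call sketch, where the registered class names are assumptions standing in for the values in the model's YAML config:

# class names below are illustrative; read the real ones from the model config
export_kwargs = dict(
    encoder="SANMEncoder",
    decoder="ParaformerSANMDecoder",
    seaco_decoder="ParaformerSANMDecoder",
    predictor="CifPredictorV3",
)
backbone, embedder = model.export(type="onnx", device="cpu", **export_kwargs)
print(backbone.export_name(), embedder.export_name())  # model.onnx model_eb.onnx
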
diff --git a/funasr/models/sond/encoder/ci_scorers.py b/funasr/models/sond/encoder/ci_scorers.py
index 50056ee..c5f45b9 100644
--- a/funasr/models/sond/encoder/ci_scorers.py
+++ b/funasr/models/sond/encoder/ci_scorers.py
@@ -16,9 +16,6 @@
scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
return scores
- def convert_tf2torch(self, var_dict_tf, var_dict_torch):
- return {}
-
class CosScorer(torch.nn.Module):
def __init__(self):
@@ -33,6 +30,3 @@
# spk_emb: B, N, D
scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
return scores
-
- def convert_tf2torch(self, var_dict_tf, var_dict_torch):
- return {}
diff --git a/funasr/models/sond/encoder/conv_encoder.py b/funasr/models/sond/encoder/conv_encoder.py
index 3933c01..2181160 100644
--- a/funasr/models/sond/encoder/conv_encoder.py
+++ b/funasr/models/sond/encoder/conv_encoder.py
@@ -173,103 +173,3 @@
return outputs, ilens, None
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
-
- "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- }
- if self.out_units is not None:
- # add output layer
- map_dict_local.update({
- "{}.conv_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # tf: (1, 256, 256) -> torch: (256, 256, 1)
- "{}.conv_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
- "squeeze": None,
- "transpose": None,
- }, # tf: (256,) -> torch: (256,)
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- # process special (first and last) layers
- if name in map_dict:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- # process general layers
- else:
- # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
- names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
-
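
The layout note that headed the deleted map applies beyond 1x1 convolutions: kernels with width greater than one need the full three-axis permutation, not the squeeze-plus-2D-transpose used elsewhere. A standalone check, with illustrative shapes:

import numpy as np

tf_kernel = np.zeros((5, 80, 256), dtype=np.float32)  # TF: (kernel, in, out)
torch_kernel = np.transpose(tf_kernel, (2, 1, 0))     # torch: (out, in, kernel)
assert torch_kernel.shape == (256, 80, 5)
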
diff --git a/funasr/models/sond/encoder/fsmn_encoder.py b/funasr/models/sond/encoder/fsmn_encoder.py
index 129a748..fb87ee8 100644
--- a/funasr/models/sond/encoder/fsmn_encoder.py
+++ b/funasr/models/sond/encoder/fsmn_encoder.py
@@ -195,140 +195,3 @@
return inputs, ilens, None
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- # for fsmn_layers
- "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
-
- # for dnn_layers
- "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
-
- }
- if self.out_units is not None:
- # add output layer
- map_dict_local.update({
- "{}.conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- },
- "{}.conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- # process special (first and last) layers
- if name in map_dict:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- # process general layers
- else:
- # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
- names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
diff --git a/funasr/models/sond/encoder/resnet34_encoder.py b/funasr/models/sond/encoder/resnet34_encoder.py
index 8445feb..8bfe491 100644
--- a/funasr/models/sond/encoder/resnet34_encoder.py
+++ b/funasr/models/sond/encoder/resnet34_encoder.py
@@ -245,147 +245,6 @@
return features, resnet_out_lens
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- train_steps = self.tf_train_steps
- map_dict_local = {
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- "{}.pre_conv.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
- }
- for layer_idx in range(3):
- map_dict_local.update({
- "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
- },
- "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
- })
-
- for block_idx in range(len(self.layers_in_block)):
- for layer_idx in range(self.layers_in_block[block_idx]):
- for i in ["1", "2", "_sc"]:
- map_dict_local.update({
- "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- if name in map_dict:
- if "num_batches_tracked" not in name:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
- logging.info("torch tensor: {}, manually assigning to: {}".format(
- name, map_dict[name]
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
-
class ResNet34Diar(ResNet34):
def __init__(
@@ -477,147 +336,6 @@
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- train_steps = 300000
- map_dict_local = {
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- "{}.pre_conv.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
- }
- for layer_idx in range(3):
- map_dict_local.update({
- "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
- },
- "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
- })
-
- for block_idx in range(len(self.layers_in_block)):
- for layer_idx in range(self.layers_in_block[block_idx]):
- for i in ["1", "2", "_sc"]:
- map_dict_local.update({
- "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- if name in map_dict:
- if "num_batches_tracked" not in name:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
- logging.info("torch tensor: {}, manually assigning to: {}".format(
- name, map_dict[name]
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
@@ -711,143 +429,3 @@
return endpoints[self.embedding_node], ilens, None
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- train_steps = 720000
- map_dict_local = {
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- "{}.pre_conv.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
- {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- },
- "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
- }
- for layer_idx in range(3):
- map_dict_local.update({
- "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
- },
- "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
- {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
- "squeeze": None,
- "transpose": None,
- },
- "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
- })
-
- for block_idx in range(len(self.layers_in_block)):
- for layer_idx in range(self.layers_in_block[block_idx]):
- for i in ["1", "2", "_sc"]:
- map_dict_local.update({
- "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": (3, 2, 0, 1),
- },
- "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
- {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
- "squeeze": None,
- "transpose": None,
- },
- "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- if name in map_dict:
- if "num_batches_tracked" not in name:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
- logging.info("torch tensor: {}, manually assigning to: {}".format(
- name, map_dict[name]
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/sond/encoder/self_attention_encoder.py b/funasr/models/sond/encoder/self_attention_encoder.py
index ea974c6..f3c4736 100644
--- a/funasr/models/sond/encoder/self_attention_encoder.py
+++ b/funasr/models/sond/encoder/self_attention_encoder.py
@@ -326,153 +326,3 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- # cicd
- # torch: conv1d.weight in "out_channel in_channel kernel_size"
- # tf : conv1d.weight in "kernel_size in_channel out_channel"
- # torch: linear.weight in "out_channel in_channel"
- # tf : dense.weight in "in_channel out_channel"
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- }
- if self.out_units is not None:
- map_dict_local.update({
- "{}.output_linear.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- },
- "{}.output_linear.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- })
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- if name.startswith(self.tf2torch_tensor_name_prefix_torch):
- # process special (first and last) layers
- if name in map_dict:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- # process general layers
- else:
- # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
- names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), \
- "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[name].size(), data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
- name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
- ))
- else:
- logging.warning("{} is missed from tf checkpoint".format(name))
-
- return var_dict_torch_update
diff --git a/funasr/models/sond/pooling/statistic_pooling.py b/funasr/models/sond/pooling/statistic_pooling.py
index 392e333..a3d7e34 100644
--- a/funasr/models/sond/pooling/statistic_pooling.py
+++ b/funasr/models/sond/pooling/statistic_pooling.py
@@ -38,9 +38,6 @@
return stat_pooling
- def convert_tf2torch(self, var_dict_tf, var_dict_torch):
- return {}
-
def statistic_pooling(
xs_pad: torch.Tensor,
diff --git a/funasr/models/transformer/positionwise_feed_forward.py b/funasr/models/transformer/positionwise_feed_forward.py
index 7ca55cb..081ff5b 100644
--- a/funasr/models/transformer/positionwise_feed_forward.py
+++ b/funasr/models/transformer/positionwise_feed_forward.py
@@ -34,3 +34,16 @@
return self.w_2(self.dropout(self.activation(self.w_1(x))))
+class PositionwiseFeedForwardDecoderSANMExport(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.w_1 = model.w_1
+        self.w_2 = model.w_2
+        self.activation = model.activation
+        self.norm = model.norm
+
+    def forward(self, x):
+        x = self.activation(self.w_1(x))
+        x = self.w_2(self.norm(x))
+        return x
+
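A quick note on the class added above: unlike the training-time forward (which applies dropout between the activation and w_2), this export variant applies the module's own norm before w_2 and omits dropout, keeping the traced graph deterministic. A minimal sketch of exercising the wrapper in isolation; DummyFFN and its dimensions are illustrative assumptions, not part of this patch:

    import torch
    from funasr.models.transformer.positionwise_feed_forward import (
        PositionwiseFeedForwardDecoderSANMExport,
    )

    class DummyFFN(torch.nn.Module):
        # Stand-in for a trained SANM decoder FFN; the export wrapper only
        # needs the w_1, w_2, activation and norm attributes to exist.
        def __init__(self, d_model=256, d_ff=1024):
            super().__init__()
            self.w_1 = torch.nn.Linear(d_model, d_ff)
            self.w_2 = torch.nn.Linear(d_ff, d_model)
            self.activation = torch.nn.ReLU()
            self.norm = torch.nn.LayerNorm(d_ff)

    ffn_export = PositionwiseFeedForwardDecoderSANMExport(DummyFFN())
    out = ffn_export(torch.randn(2, 10, 256))  # (batch, time, d_model)
    assert out.shape == (2, 10, 256)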
diff --git a/funasr/models/transformer/utils/subsampling.py b/funasr/models/transformer/utils/subsampling.py
index 64a9dbc..088675e 100644
--- a/funasr/models/transformer/utils/subsampling.py
+++ b/funasr/models/transformer/utils/subsampling.py
@@ -368,49 +368,6 @@
x_len = (x_len - 1) // self.stride + 1
return x, x_len
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## predictor
- "{}.conv.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (256,256,3),(3,256,256)
- "{}.conv.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- }
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-
- var_dict_torch_update[name] = data_tf
-
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
- return var_dict_torch_update
class StreamingConvInput(torch.nn.Module):
"""Streaming ConvInput module definition.
diff --git a/runtime/python/onnxruntime/demo_contextual_paraformer.py b/runtime/python/onnxruntime/demo_contextual_paraformer.py
index 4f8fdbd..d8aee23 100644
--- a/runtime/python/onnxruntime/demo_contextual_paraformer.py
+++ b/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -1,12 +1,12 @@
from funasr_onnx import ContextualParaformer
from pathlib import Path
-model_dir = "../export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" # your export dir
+model_dir = "damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)
-wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
+wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
-hotwords = '随机热词 各种热词 魔搭 阿里巴巴 浠�'
+hotwords = '你的热词 魔搭'
result = model(wav_path, hotwords)
print(result)
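Note the hotword calling convention this demo relies on: hotwords go in as one space-delimited string, one hotword per whitespace-separated token. Assuming the wrapper splits on whitespace (not spelled out in this patch), a longer list can simply be joined before the call:

    hotwords = ' '.join(['魔搭', '阿里巴巴', '随机热词'])  # illustrative hotword list
    result = model(wav_path, hotwords)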
diff --git a/runtime/python/onnxruntime/demo_paraformer_offline.py b/runtime/python/onnxruntime/demo_paraformer_offline.py
index bc8355b..c6fc79c 100644
--- a/runtime/python/onnxruntime/demo_paraformer_offline.py
+++ b/runtime/python/onnxruntime/demo_paraformer_offline.py
@@ -2,14 +2,14 @@
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model = Paraformer(model_dir, batch_size=1, quantize=True)
+# model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=False)
# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
-wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
+wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
result = model(wav_path)
print(result)
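One parameter worth calling out: as the commented-out block removed from paraformer_bin.py later in this patch suggests, quantize=True makes the wrapper load model_quant.onnx instead of model.onnx, and under the new auto-export path the flag is forwarded to model.export(type="onnx", quantize=quantize). A hedged one-liner for the quantized variant:

    # int8-quantized weights: smaller, and often faster on CPU (illustrative usage)
    model = Paraformer(model_dir, batch_size=1, quantize=True)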
diff --git a/runtime/python/onnxruntime/demo_paraformer_online.py b/runtime/python/onnxruntime/demo_paraformer_online.py
index b5c9371..210f3ea 100644
--- a/runtime/python/onnxruntime/demo_paraformer_online.py
+++ b/runtime/python/onnxruntime/demo_paraformer_online.py
@@ -3,7 +3,7 @@
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
-wav_path = '{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav'.format(Path.home())
+wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
chunk_size = [5, 10, 5]
model = Paraformer(model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4) # only support batch_size = 1
diff --git a/runtime/python/onnxruntime/demo_seaco_paraformer.py b/runtime/python/onnxruntime/demo_seaco_paraformer.py
new file mode 100644
index 0000000..29cff3e
--- /dev/null
+++ b/runtime/python/onnxruntime/demo_seaco_paraformer.py
@@ -0,0 +1,12 @@
+from funasr_onnx import SeacoParaformer
+from pathlib import Path
+
+model_dir = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = SeacoParaformer(model_dir, batch_size=1)
+
+wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
+
+hotwords = '你的热词 魔搭'
+
+result = model(wav_path, hotwords)
+print(result)
diff --git a/runtime/python/onnxruntime/funasr_onnx/__init__.py b/runtime/python/onnxruntime/funasr_onnx/__init__.py
index c03d7e5..d0d6651 100644
--- a/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
-from .paraformer_bin import Paraformer, ContextualParaformer
+from .paraformer_bin import Paraformer, ContextualParaformer, SeacoParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 82548ad..af3c9b9 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -266,22 +266,32 @@
model_bb_file = os.path.join(model_dir, 'model.onnx')
model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
- token_list_file = os.path.join(model_dir, 'tokens.txt')
- self.vocab = {}
- with open(Path(token_list_file), 'r') as fin:
- for i, line in enumerate(fin.readlines()):
- self.vocab[line.strip()] = i
+        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
+            print("onnx models do not exist, begin to export onnx")
+            try:
+                from funasr import AutoModel
+            except ImportError:
+                raise ImportError("Exporting onnx requires funasr; please install it and try again:\n"
+                                  "\npip3 install -U funasr\n"
+                                  "For users in China, you can install from a mirror:\n"
+                                  "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple")
-        #if quantize:
-        #    model_file = os.path.join(model_dir, 'model_quant.onnx')
-        #if not os.path.exists(model_file):
-        #    logging.error(".onnx model not exist, please export first.")
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)
+        token_list = os.path.join(model_dir, 'tokens.json')
+        with open(token_list, 'r', encoding='utf-8') as f:
+            token_list = json.load(f)
+
+        # invert token_list into a token-to-id vocab dict
+        self.vocab = {}
+        for i, token in enumerate(token_list):
+            self.vocab[token] = i
- self.converter = TokenIDConverter(config['token_list'])
+ self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(
cmvn_file=cmvn_file,
@@ -389,4 +399,8 @@
token = self.converter.ids2tokens(token_int)
token = token[:valid_token_num-self.pred_bias]
# texts = sentence_postprocess(token)
- return token
\ No newline at end of file
+ return token
+
+
+class SeacoParaformer(ContextualParaformer):
+    pass  # no difference from ContextualParaformer in how the onnx models are called
\ No newline at end of file
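Two usage notes on the changes above. First, tokens.json is read as a plain JSON array ordered by token id, so rebuilding the token-to-id vocab is a one-line inversion; a sketch with a toy vocabulary (the five tokens are invented for illustration, the real file is written by the exporter):

    import json

    # Toy tokens.json in the assumed format: a JSON list indexed by token id.
    with open("tokens.json", "w", encoding="utf-8") as f:
        json.dump(["<blank>", "<s>", "</s>", "你", "好"], f, ensure_ascii=False)

    with open("tokens.json", "r", encoding="utf-8") as f:
        token_list = json.load(f)
    vocab = {token: i for i, token in enumerate(token_list)}
    assert vocab["你"] == 3

Second, since SeacoParaformer subclasses ContextualParaformer without overriding anything, it inherits both the lazy onnx export in the constructor and the (wav_path, hotwords) call signature; demo_seaco_paraformer.py above shows the intended usage.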
--
Gitblit v1.9.1