Shi Xian
2024-03-13 e04489ce4c0fd0095d0c79ef8f504f425e0435a8
contextual&seaco ONNX export (#1481)

* contextual&seaco ONNX export

* update ContextualEmbedderExport2

* update ContextualEmbedderExport2

* update code

* onnx (#1482)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx

* export onnx

* dingding

* dingding

* llm

* doc

* onnx

* onnx

* onnx

* onnx

* onnx

* onnx

* v1.0.15

* qwenaudio

* qwenaudio

* issue doc

* update

* update

* bugfix

* onnx

* update export calling

* update codes

* remove useless code

* update code

---------

Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
23个文件已修改
6个文件已添加
2897 ■■■■ 已修改文件
funasr/models/bicif_paraformer/export_meta.py 91 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bicif_paraformer/model.py 89 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/decoder.py 563 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/export_meta.py 108 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/contextual_paraformer/model.py 19 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/export_meta.py 59 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py 48 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/decoder.py 76 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/export_meta.py 85 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 85 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/attention.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/decoder.py 380 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/scama/encoder.py 157 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/export_meta.py 181 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/seaco_paraformer/model.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/encoder/ci_scorers.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/encoder/conv_encoder.py 100 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/encoder/fsmn_encoder.py 137 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/encoder/resnet34_encoder.py 422 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/encoder/self_attention_encoder.py 150 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/pooling/statistic_pooling.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transformer/positionwise_feed_forward.py 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transformer/utils/subsampling.py 43 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/demo_contextual_paraformer.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/demo_paraformer_offline.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/demo_paraformer_online.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/demo_seaco_paraformer.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/funasr_onnx/__init__.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bicif_paraformer/export_meta.py
New file
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import types
from funasr.register import tables
def export_rebuild_model(model, **kwargs):
    """Rewire ``model`` in place for export and return it.

    The encoder, predictor and decoder are each replaced by the class that
    ``tables`` registers under ``"<name>Export"``, a padding-mask builder is
    attached as ``model.make_pad_mask``, and the module-level ``export_*``
    functions are bound to the model as methods (``export_forward`` becomes
    ``model.forward``).

    Args:
        model: Model to modify; must expose ``encoder``, ``predictor`` and
            ``decoder`` attributes.
        **kwargs: Export options.
            ``encoder`` / ``predictor`` / ``decoder`` (required): registry
            names; ``"Export"`` is appended before the lookup.
            ``type``: the wrappers receive ``onnx=True`` when this equals
            ``"onnx"`` (the default when omitted).
            ``max_seq_len``: forwarded to ``sequence_mask``; defaults to 512
            when omitted, matching the default applied by the model's
            ``export`` method.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        KeyError: If ``encoder``, ``predictor`` or ``decoder`` is missing
            from ``kwargs``.
    """
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    # Previously a bare kwargs['max_seq_len'] lookup: calling this helper
    # directly (without going through model.export) raised KeyError.
    max_seq_len = kwargs.get("max_seq_len", 512)

    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)

    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    from funasr.utils.torch_function import sequence_mask

    model.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    # Bind the free functions below as methods of this particular instance.
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)

    return model
def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
):
    """Forward pass installed as ``forward`` on the rebuilt export model.

    Runs encoder -> predictor -> decoder, applies ``log_softmax`` over the
    last axis of the decoder output, and asks the predictor for the
    upsampled timestamp signals.

    Args:
        speech: Input features, passed to the encoder as ``speech``.
        speech_lengths: Passed to the encoder as ``speech_lengths``.

    Returns:
        Tuple ``(decoder_out, pre_token_length, us_alphas, us_cif_peak)``,
        where ``pre_token_length`` has been rounded and cast to int32.
    """
    enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)

    # The predictor receives the padding mask with a singleton middle axis.
    mask = self.make_pad_mask(enc_len)[:, None, :]

    # Only the first two of the predictor's four outputs are used here.
    pre_acoustic_embeds, raw_token_length, _, _ = self.predictor(enc, mask)
    pre_token_length = raw_token_length.round().type(torch.int32)

    logits, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
    decoder_out = torch.log_softmax(logits, dim=-1)

    # Predicted timestamps.
    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(
        enc, mask, pre_token_length
    )
    return decoder_out, pre_token_length, us_alphas, us_cif_peak
def export_dummy_inputs(self):
    """Return a ``(speech, speech_lengths)`` pair of example inputs.

    ``speech`` is a random-normal tensor of shape ``(2, 30, 560)`` and
    ``speech_lengths`` is the int32 tensor ``[6, 30]``.
    Presumably consumed as tracing inputs by the export routine -- confirm
    against the caller.
    """
    batch_size, num_frames, feat_dim = 2, 30, 560
    feats = torch.randn(batch_size, num_frames, feat_dim)
    feats_lengths = torch.tensor([6, num_frames], dtype=torch.int32)
    return feats, feats_lengths
def export_input_names(self):
    """Return the names of the exported graph's inputs, in argument order."""
    input_names = ["speech", "speech_lengths"]
    return input_names
def export_output_names(self):
    """Return the names of the exported graph's outputs, in return order."""
    output_names = ["logits", "token_num", "us_alphas", "us_cif_peak"]
    return output_names
def export_dynamic_axes(self):
    """Return the dynamic-axis spec for the exported inputs and outputs.

    Maps each tensor name to ``{axis_index: symbolic_name}``. Axis 0 is the
    batch axis for every listed tensor; axis 1, where present, is the
    variable-length axis.

    ``token_num`` is a declared output (see ``export_output_names``) and was
    previously the only one without a dynamic batch axis. It is now listed
    so that its batch dimension is not pinned to the dummy batch size.
    NOTE(review): assumes ``token_num`` is batch-first like the other
    outputs -- confirm against the predictor's export wrapper.
    """
    batch = 'batch_size'
    return {
        'speech': {
            0: batch,
            1: 'feats_length'
        },
        'speech_lengths': {
            0: batch,
        },
        'logits': {
            0: batch,
            1: 'logits_length'
        },
        'token_num': {
            0: batch,
        },
        'us_alphas': {
            0: batch,
            1: 'alphas_length'
        },
        'us_cif_peak': {
            0: batch,
            1: 'alphas_length'
        },
    }
def export_name(self):
    """Return the file name used for the exported model."""
    file_name = "model.onnx"
    return file_name
funasr/models/bicif_paraformer/model.py
@@ -343,86 +343,9 @@
        
        return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        self.device = kwargs.get("device")
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self.export_forward
        return self
    def export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # get predicted timestamps
        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
        return decoder_out, pre_token_length, us_alphas, us_cif_peak
    def export_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_input_names(self):
        return ['speech', 'speech_lengths']
    def export_output_names(self):
        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
    def export_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
            'us_alphas': {
                0: 'batch_size',
                1: 'alphas_length'
            },
            'us_cif_peak': {
                0: 'batch_size',
                1: 'alphas_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"
    def export(self, **kwargs):
        from .export_meta import export_rebuild_model
        if 'max_seq_len' not in kwargs:
            kwargs['max_seq_len'] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
funasr/models/contextual_paraformer/decoder.py
@@ -305,473 +305,128 @@
            x = self.output_layer(x)
        return x, olens
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
@tables.register("decoder_classes", "ContextualParaformerDecoderExport")
class ContextualParaformerDecoderExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,
                 **kwargs,):
        super().__init__()
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        from funasr.models.paraformer.decoder import DecoderLayerSANMExport
        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANMExport
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                    "squeeze": 0,
                    "transpose": (1, 2, 0),
                },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANMExport(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        # bias decoder
        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.bias_decoder.src_attn)
        self.bias_decoder = self.model.bias_decoder
        # last decoder
        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAttExport(self.model.last_decoder.src_attn)
        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoderExport(self.model.last_decoder.self_attn)
        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANMExport(self.model.last_decoder.feed_forward)
        self.last_decoder = self.model.last_decoder
        self.bias_output = self.model.bias_output
        self.dropout = self.model.dropout
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        bias_embed: torch.Tensor,
    ):
            ## clas decoder
            # src att
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (1024,256),(1,256,1024)
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        }
        return map_dict_local
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "last_decoder":
                    layeridx = 15
                    name_q = name.replace("last_decoder", "decoders.layeridx")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )
                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
        # contextual paraformer related
        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        contextual_mask = self.make_pad_mask(contextual_length)
        contextual_mask, _ = self.prepare_mask(contextual_mask)
        # import pdb; pdb.set_trace()
        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "bias_decoder":
                    name_q = name
        return x, ys_in_lens
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        idx_list = 0
                        if name_tf[idx_list] in var_dict_tf.keys():
                            pass
                        else:
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
                                                                                                   name_tf[idx_list],
                                                                                                   var_dict_tf[name_tf[
                                                                                                       idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                      var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
        return var_dict_torch_update
funasr/models/contextual_paraformer/export_meta.py
New file
@@ -0,0 +1,108 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import types
from funasr.register import tables
from funasr.models.seaco_paraformer.export_meta import ContextualEmbedderExport
class ContextualEmbedderExport2(ContextualEmbedderExport):
    """Export wrapper for the hotword embedder of the contextual Paraformer.

    Builds on ``ContextualEmbedderExport`` but takes the embedding from
    ``model.bias_embed`` and the encoder from ``model.bias_encoder``.
    """

    def __init__(self, model, **kwargs):
        # Extra keyword arguments (e.g. ``onnx``) are accepted but are not
        # forwarded to the parent constructor.
        super().__init__(model)
        self.embedding = model.bias_embed
        bias_encoder = model.bias_encoder
        # The flag is flipped on the shared encoder object, so ``model``
        # observes the change as well.
        bias_encoder.batch_first = False
        self.bias_encoder = bias_encoder
def export_rebuild_model(model, **kwargs):
    """Rebuild a contextual Paraformer into export-ready sub-models.

    The encoder, predictor and decoder of ``model`` are replaced in place by
    the ``*Export`` classes registered in ``tables``, a padding-mask helper
    is attached, and a shallow copy of the model receives the backbone
    export hooks (forward, dummy inputs, names, dynamic axes).

    Args:
        model: Model to convert; its sub-modules are overwritten in place.
        **kwargs: Export options. ``encoder``, ``predictor`` and ``decoder``
            are required registry names; ``type`` selects the backend and
            defaults to ``"onnx"``; ``max_seq_len`` sizes the padding-mask
            helper and defaults to 512 when omitted.

    Returns:
        The tuple ``(backbone_model, embedder_model)``.

    Raises:
        KeyError: If ``encoder``, ``predictor`` or ``decoder`` is missing.
    """
    import copy

    from funasr.utils.torch_function import sequence_mask

    is_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)

    # The bias encoder differs slightly from the SeACo Paraformer one, hence
    # the dedicated embedder wrapper.
    embedder_model = ContextualEmbedderExport2(model, onnx=is_onnx)

    # A plain SANM decoder is exported through the "Online" export class.
    decoder_name = kwargs["decoder"]
    if decoder_name == "ParaformerSANMDecoder":
        decoder_name = "ParaformerSANMDecoderOnline"
    decoder_class = tables.decoder_classes.get(decoder_name + "Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    # Fall back to 512 (the default applied by the model's ``export`` method)
    # so that calling this helper directly does not fail with a KeyError.
    max_seq_len = kwargs.get("max_seq_len", 512)
    model.make_pad_mask = sequence_mask(max_seq_len, flip=False)
    model.feats_dim = 560

    # The hooks are bound to a shallow copy; ``model`` itself keeps its own
    # ``forward``.
    backbone_model = copy.copy(model)
    backbone_model.forward = types.MethodType(export_backbone_forward, backbone_model)
    backbone_model.export_dummy_inputs = types.MethodType(export_backbone_dummy_inputs, backbone_model)
    backbone_model.export_input_names = types.MethodType(export_backbone_input_names, backbone_model)
    backbone_model.export_output_names = types.MethodType(export_backbone_output_names, backbone_model)
    backbone_model.export_dynamic_axes = types.MethodType(export_backbone_dynamic_axes, backbone_model)
    backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)

    return backbone_model, embedder_model
def export_backbone_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    bias_embed: torch.Tensor,
):
    """Run encoder, predictor and biased decoder for the exported backbone.

    Args:
        speech: Passed to ``self.encoder`` as ``speech``.
        speech_lengths: Passed to ``self.encoder`` as ``speech_lengths``.
        bias_embed: Forwarded unchanged to ``self.decoder`` as its last
            positional argument.

    Returns:
        ``(log_probs, token_num)``: the log-softmax over the last dimension
        of the first decoder output, and the predictor's token length
        floored and cast to ``int32``.
    """
    enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)
    # Insert a singleton middle axis before handing the mask to the predictor.
    pad_mask = self.make_pad_mask(enc_len)[:, None, :]
    acoustic_embeds, token_num, _, _ = self.predictor(enc, pad_mask)
    token_num = token_num.floor().type(torch.int32)
    logits, _ = self.decoder(enc, enc_len, acoustic_embeds, token_num, bias_embed)
    return torch.log_softmax(logits, dim=-1), token_num
def export_backbone_dummy_inputs(self):
    """Return random ``(speech, speech_lengths, bias_embed)`` tracing inputs.

    ``speech`` is ``(2, 30, self.feats_dim)``, ``speech_lengths`` is the
    ``int32`` tensor ``[6, 30]`` and ``bias_embed`` is ``(2, 1, 512)``.
    """
    batch_size, num_frames = 2, 30
    speech = torch.randn(batch_size, num_frames, self.feats_dim)
    speech_lengths = torch.tensor([6, num_frames], dtype=torch.int32)
    bias_embed = torch.randn(batch_size, 1, 512)
    return speech, speech_lengths, bias_embed
def export_backbone_input_names(self):
    """Return the backbone input names, in the order of the dummy inputs."""
    input_names = ["speech", "speech_lengths", "bias_embed"]
    return input_names
def export_backbone_output_names(self):
    """Return the names given to the two backbone outputs."""
    output_names = ["logits", "token_num"]
    return output_names
def export_backbone_dynamic_axes(self):
    """Return the dynamic-axis names, keyed by tensor name then axis index."""
    batch = "batch_size"
    return {
        "speech": {0: batch, 1: "feats_length"},
        "speech_lengths": {0: batch},
        "bias_embed": {0: batch, 1: "num_hotwords"},
        "logits": {0: batch, 1: "logits_length"},
    }
def export_backbone_name(self):
    """Return the constant string ``"model.onnx"``."""
    name = "model.onnx"
    return name
funasr/models/contextual_paraformer/model.py
@@ -17,9 +17,6 @@
from distutils.version import LooseVersion
from funasr.register import tables
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
@@ -29,7 +26,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -80,7 +77,6 @@
        if self.crit_attn_weight > 0:
            self.attn_loss = torch.nn.L1Loss()
        self.crit_attn_smooth = crit_attn_smooth
    def forward(
        self,
@@ -156,7 +152,6 @@
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    
    def _calc_att_clas_loss(
        self,
        encoder_out: torch.Tensor,
@@ -231,7 +226,6 @@
        
        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
    
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad = ys_pad * tgt_mask[:, :, 0]
@@ -263,7 +257,6 @@
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    
    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
                                   clas_scale=1.0):
@@ -414,7 +407,6 @@
        
        return results, meta_data
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
@@ -516,3 +508,12 @@
            hotword_list = None
        return hotword_list
    def export(
        self,
        **kwargs,
    ):
        if 'max_seq_len' not in kwargs:
            kwargs['max_seq_len'] = 512
        from .export_meta import export_rebuild_model
        models = export_rebuild_model(model=self, **kwargs)
        return models
funasr/models/fsmn_vad_streaming/export_meta.py
New file
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import types
import torch
from funasr.register import tables
def export_rebuild_model(model, **kwargs):
    """Swap in the export encoder and attach the export hooks to ``model``.

    Args:
        model: Model to convert; it is modified in place and returned.
        **kwargs: ``encoder`` is the required registry name of the encoder;
            ``type`` selects the backend and defaults to ``"onnx"``.

    Returns:
        The same ``model`` object.
    """
    use_onnx = kwargs.get("type", "onnx") == "onnx"
    export_encoder_cls = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = export_encoder_cls(model.encoder, onnx=use_onnx)

    # Bind the module-level helpers as instance methods, in a fixed order.
    hooks = (
        ("forward", export_forward),
        ("export_dummy_inputs", export_dummy_inputs),
        ("export_input_names", export_input_names),
        ("export_output_names", export_output_names),
        ("export_dynamic_axes", export_dynamic_axes),
        ("export_name", export_name),
    )
    for attr_name, func in hooks:
        setattr(model, attr_name, types.MethodType(func, model))
    return model
def export_forward(self, feats: torch.Tensor, *args, **kwargs):
    """Call ``self.encoder`` with ``feats`` and the extra positional inputs.

    Keyword arguments are accepted but ignored. Returns the encoder's
    two-element result as ``(scores, out_caches)``.
    """
    encoder_out = self.encoder(feats, *args)
    scores, out_caches = encoder_out
    return (scores, out_caches)
def export_dummy_inputs(self, data_in=None, frame=30):
    """Return random ``(speech, in_cache0..in_cache3)`` tracing inputs.

    Sizes are read from ``self.encoder_conf``: ``speech`` is
    ``(1, frame, input_dim)`` and each cache is
    ``(1, proj_dim, lorder + rorder - 1, 1)``.

    NOTE(review): when ``data_in`` is given, ``speech`` is returned as
    ``None`` (the original marks this branch "Undo") -- looks unfinished,
    confirm the intended behaviour.
    """
    conf = self.encoder_conf
    speech = torch.randn(1, frame, conf.get("input_dim")) if data_in is None else None
    cache_frames = conf.get("lorder") + conf.get("rorder") - 1
    caches = tuple(
        torch.randn(1, conf.get("proj_dim"), cache_frames, 1) for _ in range(4)
    )
    return (speech, *caches)
def export_input_names(self):
    """Return ``speech`` followed by the four ``in_cache<i>`` names."""
    return ["speech"] + [f"in_cache{idx}" for idx in range(4)]
def export_output_names(self):
    """Return ``logits`` followed by the four ``out_cache<i>`` names."""
    return ["logits"] + [f"out_cache{idx}" for idx in range(4)]
def export_dynamic_axes(self):
    """Return the dynamic-axis map: only axis 1 of ``speech`` is named."""
    speech_axes = {1: "feats_length"}
    return {"speech": speech_axes}
def export_name(self):
    """Return the constant string ``"model.onnx"``."""
    name = "model.onnx"
    return name
funasr/models/fsmn_vad_streaming/model.py
@@ -644,49 +644,11 @@
        return results, meta_data
    
    def export(self, **kwargs):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        self.forward = self.export_forward
        return self
    def export_forward(self, feats: torch.Tensor, *args, **kwargs):
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches
    def export_dummy_inputs(self, data_in=None, frame=30):
        if data_in is None:
            speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
        else:
            speech = None # Undo
        cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
        in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
    def export_input_names(self):
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
    def export_output_names(self):
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
    def export_dynamic_axes(self):
        return {
            'speech': {
                1: 'feats_length'
            },
        }
    def export_name(self, ):
        # Returns the constant string "model.onnx".
        return "model.onnx"
        # NOTE(review): the statements below follow the return above, so they
        # can never execute here (and ``kwargs`` is not defined in this
        # method). They read like the body of the new ``export`` method --
        # confirm against the real file.
        from .export_meta import export_rebuild_model
        models = export_rebuild_model(model=self, **kwargs)
        return models
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
funasr/models/paraformer/decoder.py
@@ -616,6 +616,22 @@
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
@@ -675,6 +691,8 @@
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
    ):
        
        tgt = ys_in_pad
@@ -698,11 +716,60 @@
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
        hidden = self.after_norm(x)
        # x = self.output_layer(x)
        
        return x, ys_in_lens
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, ys_in_lens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, ys_in_lens
        return hidden, ys_in_lens
    
    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    '''
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
@@ -751,7 +818,8 @@
            for d in range(cache_num)
        })
        return ret
    '''
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    def __init__(self, model,
funasr/models/paraformer/export_meta.py
New file
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import types
import torch
from funasr.register import tables
def export_rebuild_model(model, **kwargs):
    """Turn a Paraformer model into its export form, in place.

    The encoder, predictor and decoder are replaced by the ``*Export``
    classes registered in ``tables``, a padding-mask helper is attached and
    the module-level export helpers are bound onto ``model`` as methods.

    Args:
        model: Model to convert; it is modified in place and returned.
        **kwargs: Export options. ``encoder``, ``predictor`` and ``decoder``
            are required registry names; ``type`` selects the backend and
            defaults to ``"onnx"``; ``device`` is stored on the model
            (``None`` when absent); ``max_seq_len`` sizes the padding-mask
            helper and defaults to 512 when omitted.

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: If ``encoder``, ``predictor`` or ``decoder`` is missing.
    """
    from funasr.utils.torch_function import sequence_mask

    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)

    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    # Fall back to 512 (the default applied by the model's ``export`` method)
    # so that calling this helper directly does not fail with a KeyError.
    max_seq_len = kwargs.get("max_seq_len", 512)
    model.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)

    return model
def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
):
    """Run encoder, predictor and decoder for the exported model.

    Args:
        speech: Passed to ``self.encoder`` as ``speech``.
        speech_lengths: Passed to ``self.encoder`` as ``speech_lengths``.

    Returns:
        ``(log_probs, token_num)``: the log-softmax over the last dimension
        of the first decoder output, and the predictor's token length
        floored and cast to ``int32``.
    """
    enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)
    # Insert a singleton middle axis before handing the mask to the predictor.
    pad_mask = self.make_pad_mask(enc_len)[:, None, :]
    acoustic_embeds, token_num, _alphas, _peak_index = self.predictor(enc, pad_mask)
    token_num = token_num.floor().type(torch.int32)
    logits, _ = self.decoder(enc, enc_len, acoustic_embeds, token_num)
    return torch.log_softmax(logits, dim=-1), token_num
def export_dummy_inputs(self):
    """Return random ``(speech, speech_lengths)`` tracing inputs.

    ``speech`` is ``(2, 30, 560)`` and ``speech_lengths`` is the ``int32``
    tensor ``[6, 30]``.
    """
    lengths = [6, 30]
    speech = torch.randn(len(lengths), max(lengths), 560)
    speech_lengths = torch.tensor(lengths, dtype=torch.int32)
    return speech, speech_lengths
def export_input_names(self):
    """Return the input names, in the order of the dummy inputs."""
    input_names = ["speech", "speech_lengths"]
    return input_names
def export_output_names(self):
    """Return the names given to the two outputs."""
    output_names = ["logits", "token_num"]
    return output_names
def export_dynamic_axes(self):
    """Return the dynamic-axis names, keyed by tensor name then axis index."""
    batch = "batch_size"
    return {
        "speech": {0: batch, 1: "feats_length"},
        "speech_lengths": {0: batch},
        "logits": {0: batch, 1: "logits_length"},
    }
def export_name(self):
    """Return the constant string ``"model.onnx"``."""
    name = "model.onnx"
    return name
funasr/models/paraformer/model.py
@@ -13,15 +13,16 @@
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.train_utils.device_funcs import to_device
@tables.register("model_classes", "Paraformer")
class Paraformer(torch.nn.Module):
@@ -549,78 +550,10 @@
                
        return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        self.device = kwargs.get("device")
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
    def export(self, **kwargs):
        from .export_meta import export_rebuild_model
        if 'max_seq_len' not in kwargs:
            kwargs['max_seq_len'] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self.export_forward
        return self
    def export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length
    def export_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_input_names(self):
        return ['speech', 'speech_lengths']
    def export_output_names(self):
        return ['logits', 'token_num']
    def export_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"
funasr/models/sanm/attention.py
@@ -697,10 +697,10 @@
        self.attn = None
        self.all_head_size = self.h * self.d_k
    def forward(self, x, memory, memory_mask):
    def forward(self, x, memory, memory_mask, ret_attn=False):
        q, k, v = self.forward_qkv(x, memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, memory_mask)
        return self.forward_attention(v, scores, memory_mask, ret_attn)
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
@@ -717,7 +717,7 @@
        v = self.transpose_for_scores(v)
        return q, k, v
    def forward_attention(self, value, scores, mask):
    def forward_attention(self, value, scores, mask, ret_attn):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
@@ -726,6 +726,7 @@
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        if ret_attn: return self.linear_out(context_layer), self.attn
        return self.linear_out(context_layer)  # (batch, time1, d_model)
funasr/models/scama/decoder.py
@@ -474,383 +474,3 @@
        return y, new_cache
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                    "squeeze": 0,
                    "transpose": (1, 2, 0),
                },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
            # out layer
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "decoders2":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "embed" or names[1] == "output_layer":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        idx_list = 0
                        if name_tf[idx_list] in var_dict_tf.keys():
                            pass
                        else:
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
                                                                                                   name_tf[idx_list],
                                                                                                   var_dict_tf[name_tf[
                                                                                                       idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                      var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
        return var_dict_torch_update
funasr/models/scama/encoder.py
@@ -460,160 +460,3 @@
        return xs_pad, ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## encoder
            # cicd
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "encoders0":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("encoders0", "encoders")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "encoders":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 1
                    layeridx += layeridx_bias
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                        var_dict_torch[
                                                                                                            name].size(),
                                                                                                        data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                          var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                      var_dict_tf[name_tf].shape))
        return var_dict_torch_update
funasr/models/seaco_paraformer/export_meta.py
New file
@@ -0,0 +1,181 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
from funasr.register import tables
class ContextualEmbedderExport(torch.nn.Module):
    """Export wrapper that turns hotword token ids into hotword embeddings.

    It reuses ``model.decoder.embed`` as the token embedding and
    ``model.bias_encoder`` as the sequence encoder. Note that construction
    sets ``model.bias_encoder.batch_first = False`` on the passed-in model.
    """

    def __init__(self,
                 model,
                 max_seq_len=512,
                 feats_dim=560,
                 **kwargs,):
        """Wire up the embedding and bias encoder taken from ``model``.

        ``max_seq_len``, ``feats_dim`` and ``kwargs`` are accepted but not used.
        """
        super().__init__()
        # alternative source noted by the authors: model.bias_embed
        self.embedding = model.decoder.embed
        bias_encoder = model.bias_encoder
        # forward() feeds the encoder a tensor with dims 0 and 1 swapped
        bias_encoder.batch_first = False
        self.bias_encoder = bias_encoder

    def forward(self, hotword):
        """Embed ``hotword`` ids, swap dims 0/1, and return the encoder output."""
        embedded = self.embedding(hotword)
        swapped = embedded.transpose(0, 1)  # batch second
        hw_embed, (_, _) = self.bias_encoder(swapped)
        return hw_embed

    def export_dummy_inputs(self):
        """Return a (6, 10) int32 tensor of example hotword ids.

        The three example rows are repeated twice; shorter rows are filled
        with 0 up to length 10. The tensor itself is returned (not a tuple).
        """
        example_rows = [
            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        hotword = torch.tensor(example_rows * 2, dtype=torch.int32)
        return hotword

    def export_input_names(self):
        """Names of the exported graph inputs."""
        input_names = ['hotword']
        return input_names

    def export_output_names(self):
        """Names of the exported graph outputs."""
        output_names = ['hw_embed']
        return output_names

    def export_dynamic_axes(self):
        """Axis 0 of both the input and the output is declared dynamic."""
        axes = {}
        axes['hotword'] = {0: 'num_hotwords'}
        axes['hw_embed'] = {0: 'num_hotwords'}
        return axes

    def export_name(self):
        """File name used for the exported embedder model."""
        return 'model_eb.onnx'
def export_rebuild_model(model, **kwargs):
    """Rewrite ``model`` in place for export and return the two export models.

    The encoder, predictor, decoder and seaco decoder of ``model`` are replaced
    by the classes registered in ``tables`` under ``<name> + "Export"``. A
    shallow copy of the model then gets the ``export_backbone_*`` functions
    bound as its ``forward`` / ``export_*`` methods.

    Args:
        model: the model to convert; it is mutated.
        **kwargs: reads ``device``, ``type`` (default ``"onnx"``), ``encoder``,
            ``predictor``, ``decoder``, ``seaco_decoder`` (registry names) and
            ``max_seq_len`` (default 512, the same default the model's
            ``export`` method applies).

    Returns:
        tuple: ``(backbone_model, embedder_model)``.
    """
    import copy
    import types

    from funasr.utils.torch_function import sequence_mask

    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)

    # Build the embedder before the decoder is replaced by its export class:
    # ContextualEmbedderExport reads model.decoder.embed.
    embedder_class = ContextualEmbedderExport
    embedder_model = embedder_class(model, onnx=is_onnx)

    decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    seaco_decoder_class = tables.decoder_classes.get(kwargs["seaco_decoder"]+"Export")
    model.seaco_decoder = seaco_decoder_class(model.seaco_decoder, onnx=is_onnx)

    # Built once (the original built it twice with identical arguments).
    model.make_pad_mask = sequence_mask(kwargs.get("max_seq_len", 512), flip=False)

    model.feats_dim = 560
    model.NOBIAS = 8377

    # Shallow copy: the backbone shares all attributes set above with ``model``.
    backbone_model = copy.copy(model)

    # backbone
    backbone_model.forward = types.MethodType(export_backbone_forward, backbone_model)
    backbone_model.export_dummy_inputs = types.MethodType(export_backbone_dummy_inputs, backbone_model)
    backbone_model.export_input_names = types.MethodType(export_backbone_input_names, backbone_model)
    backbone_model.export_output_names = types.MethodType(export_backbone_output_names, backbone_model)
    backbone_model.export_dynamic_axes = types.MethodType(export_backbone_dynamic_axes, backbone_model)
    backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)

    return backbone_model, embedder_model
def export_backbone_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        bias_embed: torch.Tensor,
        # lmbd: float,
    ):
    """Backbone forward used for export: ASR decoding merged with hotword biasing.

    Args:
        speech: passed to ``self.encoder`` as ``speech``.
        speech_lengths: passed to ``self.encoder`` as ``speech_lengths``.
        bias_embed: hotword embeddings; unpacked as a rank-3 shape ``(B, N, D)``.

    Returns:
        tuple: ``(decoder_out, pre_token_length, alphas)`` where ``decoder_out``
        holds the merged log-probabilities, ``pre_token_length`` is the floored
        int32 token length from the predictor and ``alphas`` is returned by the
        predictor unchanged.
    """
    # Encode the input and run the predictor on the masked encoder output.
    batch = {"speech": speech, "speech_lengths": speech_lengths}
    enc, enc_len = self.encoder(**batch)
    mask = self.make_pad_mask(enc_len)[:, None, :]
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
    pre_token_length = pre_token_length.floor().type(torch.int32)
    # return_hidden/return_both: the decoder hands back both its output and hidden states
    decoder_out, decoder_hidden, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, return_hidden=True, return_both=True)
    decoder_out = torch.log_softmax(decoder_out, dim=-1)
    # seaco forward
    B, N, D = bias_embed.shape
    # NOTE(review): created on the default device, unlike the second
    # _contextual_length below which is moved to enc.device - confirm intended.
    _contextual_length = torch.ones(B) * N
    # ASF: score every hotword, then keep only the best-scoring ones
    hotword_scores = self.seaco_decoder.forward_asf6(bias_embed, _contextual_length, decoder_hidden, pre_token_length)
    # take element 0 and sum away its first two axes
    hotword_scores = hotword_scores[0].sum(0).sum(0)
    # _ = self.decoder2(bias_embed, _contextual_length, decoder_hidden, pre_token_length)
    # hotword_scores = self.decoder2.model.decoders[-1].attn_mat[0][0].sum(0).sum(0)
    # indices of the (at most) 51 highest scores, in descending score order
    dec_filter = torch.sort(hotword_scores, descending=True)[1][:51]
    contextual_info = bias_embed[:,dec_filter]
    num_hot_word = contextual_info.shape[1]
    _contextual_length = torch.Tensor([num_hot_word]).int().repeat(B).to(enc.device)
    # again: run the seaco decoder on the selected hotwords, once against the
    # predictor embeddings and once against the decoder hidden states
    cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, pre_acoustic_embeds, pre_token_length)
    dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, pre_token_length)
    merged = cif_attended + dec_attended
    dha_output = self.hotword_output_layer(merged)
    dha_pred = torch.log_softmax(dha_output, dim=-1)
    # merging logits: where the hotword branch's argmax is self.NOBIAS keep the
    # decoder output, everywhere else use the hotword branch's prediction
    dha_ids = dha_pred.max(-1)[-1]
    dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
    decoder_out = decoder_out * dha_mask + dha_pred * (1-dha_mask)
    return decoder_out, pre_token_length, alphas
def export_backbone_dummy_inputs(self):
    """Return example ``(speech, speech_lengths, bias_embed)`` tensors for export.

    ``speech`` is random with shape ``(2, 30, self.feats_dim)``,
    ``speech_lengths`` is int32 ``[15, 30]`` and ``bias_embed`` is random with
    shape ``(2, 1, 512)``.
    """
    batch_size, num_frames = 2, 30
    dummy_speech = torch.randn(batch_size, num_frames, self.feats_dim)
    dummy_lengths = torch.tensor([15, 30], dtype=torch.int32)
    dummy_bias = torch.randn(batch_size, 1, 512)
    return dummy_speech, dummy_lengths, dummy_bias
def export_backbone_input_names(self):
    """Names of the backbone graph inputs, in positional order."""
    input_names = ['speech', 'speech_lengths', 'bias_embed']
    return input_names
def export_backbone_output_names(self):
    """Names of the backbone graph outputs, in positional order."""
    output_names = ['logits', 'token_num', 'alphas']
    return output_names
def export_backbone_dynamic_axes(self):
    """Return the dynamic-axes spec for the backbone export.

    Only names that the backbone actually declares as inputs
    (``speech``, ``speech_lengths``, ``bias_embed``) or outputs (``logits``)
    are listed. The former ``pre_acoustic_embeds`` entry was dropped because
    it is neither an input nor an output name of this graph.

    Returns:
        dict: tensor name -> {axis index: symbolic axis name}.
    """
    return {
        'speech': {
            0: 'batch_size',
            1: 'feats_length'
        },
        'speech_lengths': {
            0: 'batch_size',
        },
        'bias_embed': {
            0: 'batch_size',
            1: 'num_hotwords'
        },
        'logits': {
            0: 'batch_size',
            1: 'logits_length'
        },
    }
def export_backbone_name(self):
    """File name used for the exported backbone model."""
    file_name = 'model.onnx'
    return file_name
funasr/models/seaco_paraformer/model.py
@@ -430,7 +430,6 @@
        
        return results, meta_data
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
@@ -532,3 +531,13 @@
            hotword_list = None
        return hotword_list
    def export(
        self,
        **kwargs,
    ):
        if 'max_seq_len' not in kwargs:
            kwargs['max_seq_len'] = 512
        from .export_meta import export_rebuild_model
        models = export_rebuild_model(model=self, **kwargs)
        return models
funasr/models/sond/encoder/ci_scorers.py
@@ -16,9 +16,6 @@
        scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
        return scores
    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}
class CosScorer(torch.nn.Module):
    def __init__(self):
@@ -33,6 +30,3 @@
        # spk_emb: B, N, D
        scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
        return scores
    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}
funasr/models/sond/encoder/conv_encoder.py
@@ -173,103 +173,3 @@
        return outputs, ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv_out.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },  # tf: (1, 256, 256) -> torch: (256, 256, 1)
                "{}.conv_out.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": None,
                     },  # tf: (256,) -> torch: (256,)
            })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
funasr/models/sond/encoder/fsmn_encoder.py
@@ -195,140 +195,3 @@
        return inputs, ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            # for fsmn_layers
            "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
            # for dnn_layers
            "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
        }
        if self.out_units is not None:
            # add output layer
            map_dict_local.update({
                "{}.conv1d.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },
                "{}.conv1d.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": None,
                     },
            })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
funasr/models/sond/encoder/resnet34_encoder.py
@@ -245,147 +245,6 @@
        return features, resnet_out_lens
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = self.tf_train_steps
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })
        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
class ResNet34Diar(ResNet34):
    def __init__(
@@ -477,147 +336,6 @@
        endpoints["resnet2_bn"] = features
        return endpoints[self.embedding_node], ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = 300000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })
        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
@@ -711,143 +429,3 @@
        return endpoints[self.embedding_node], ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        train_steps = 720000
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            "{}.pre_conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (3, 2, 0, 1),
                 },
            "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
                {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
        }
        for layer_idx in range(3):
            map_dict_local.update({
                "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
                     },
                "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
                    {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
                     "squeeze": None,
                     "transpose": None,
                     },
                "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
            })
        for block_idx in range(len(self.layers_in_block)):
            for layer_idx in range(self.layers_in_block[block_idx]):
                for i in ["1", "2", "_sc"]:
                    map_dict_local.update({
                        "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": (3, 2, 0, 1),
                             },
                        "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
                            {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
                             "squeeze": None,
                             "transpose": None,
                             },
                        "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
                    })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                if name in map_dict:
                    if "num_batches_tracked" not in name:
                        name_tf = map_dict[name]["name"]
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu")
                        logging.info("torch tensor: {}, manually assigning to: {}".format(
                            name, map_dict[name]
                        ))
                else:
                    logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
funasr/models/sond/encoder/self_attention_encoder.py
@@ -326,153 +326,3 @@
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # cicd
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   :  dense.weight in "in_channel out_channel"
            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (768,256),(1,256,768)
            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (768,),(768,)
            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # ffn
            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)
            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        if self.out_units is not None:
            map_dict_local.update({
                "{}.output_linear.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                     "squeeze": 0,
                     "transpose": (1, 0),
                     },
                "{}.output_linear.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                     "squeeze": None,
                     "transpose": None,
                     },  # (256,),(256,)
            })
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update
funasr/models/sond/pooling/statistic_pooling.py
@@ -38,9 +38,6 @@
        return stat_pooling
    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        return {}
def statistic_pooling(
        xs_pad: torch.Tensor,
funasr/models/transformer/positionwise_feed_forward.py
@@ -34,3 +34,16 @@
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
class PositionwiseFeedForwardDecoderSANMExport(torch.nn.Module):
    """Feed-forward block rebuilt from an existing module for export.

    The ``w_1``, ``w_2``, ``activation`` and ``norm`` sub-modules are taken
    from ``model`` by reference (not copied), so their parameters stay shared
    with the wrapped module.
    """

    def __init__(self, model):
        """Borrow the layers of ``model``.

        Args:
            model: object exposing ``w_1``, ``w_2``, ``activation`` and
                ``norm`` attributes.
        """
        super().__init__()
        self.w_1 = model.w_1
        self.w_2 = model.w_2
        self.activation = model.activation
        self.norm = model.norm

    def forward(self, x):
        """Compute ``w_2(norm(activation(w_1(x))))``; no dropout is applied."""
        hidden = self.w_1(x)
        hidden = self.activation(hidden)
        hidden = self.norm(hidden)
        return self.w_2(hidden)
funasr/models/transformer/utils/subsampling.py
@@ -368,49 +368,6 @@
        x_len = (x_len - 1) // self.stride + 1
        return x, x_len
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## predictor
            "{}.conv.weight".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (256,256,3),(3,256,256)
            "{}.conv.bias".format(tensor_name_prefix_torch):
                {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
        }
        return map_dict_local
    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update
class StreamingConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.
runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -1,12 +1,12 @@
from funasr_onnx import ContextualParaformer
from pathlib import Path
# Demo: run ContextualParaformer on the example wav with a hotword string.
# NOTE(review): several names below are assigned twice in a row; only the second
# assignment takes effect (the first looks like the pre-change line of this diff hunk).
model_dir = "../export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" # your export dir
# Effective value: the identifier passed to ContextualParaformer and reused for the wav path.
model_dir = "damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)
# Superseded: hard-coded path to the example audio.
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
# Effective value: same location, built from model_dir under ~/.cache/modelscope/hub.
wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
# Superseded hotword string.
hotwords = '随机热词 各种热词 魔搭 阿里巴巴 仏'
# Effective hotword string passed as the second argument to the model call.
hotwords = '你的热词 魔搭'
result = model(wav_path, hotwords)
print(result)
runtime/python/onnxruntime/demo_paraformer_offline.py
@@ -2,14 +2,14 @@
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1, quantize=True)
# model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1, quantize=False)
# model = Paraformer(model_dir, batch_size=1, device_id=0)  # gpu
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
result = model(wav_path)
print(result)
runtime/python/onnxruntime/demo_paraformer_online.py
@@ -3,7 +3,7 @@
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
wav_path = '{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav'.format(Path.home())
wav_path = ['{}/.cache/modelscope/hub/{}/example/asr_example.wav'.format(Path.home(), model_dir)]
chunk_size = [5, 10, 5]
model = Paraformer(model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4) # only support batch_size = 1
runtime/python/onnxruntime/demo_seaco_paraformer.py
New file
@@ -0,0 +1,12 @@
from funasr_onnx import SeacoParaformer
from pathlib import Path
# Model identifier handed to SeacoParaformer and reused to locate the example audio.
model_dir = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = SeacoParaformer(model_dir, batch_size=1)

# Example recording under ~/.cache/modelscope/hub/<model_dir>/example/.
example_wav = f'{Path.home()}/.cache/modelscope/hub/{model_dir}/example/asr_example.wav'
wav_path = [example_wav]

# Hotword string passed alongside the audio list.
hotwords = '你的热词 魔搭'

result = model(wav_path, hotwords)
print(result)
runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer, ContextualParaformer
from .paraformer_bin import Paraformer, ContextualParaformer, SeacoParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -266,22 +266,32 @@
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i
        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
        #if quantize:
        #    model_file = os.path.join(model_dir, 'model_quant.onnx')
        #if not os.path.exists(model_file):
        #    logging.error(".onnx model not exist, please export first.")
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize)
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        # revert token_list into vocab dict
        self.vocab = {}
        for i, token in enumerate(token_list):
                self.vocab[token] = i
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
@@ -389,4 +399,8 @@
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
        return token
class SeacoParaformer(ContextualParaformer):
    """Alias-style subclass of ContextualParaformer.

    There is no difference from ContextualParaformer in how the ONNX models
    are called, so everything is inherited unchanged.
    """