zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/paraformer/export_meta.py
@@ -17,12 +17,12 @@
        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        model.predictor = predictor_class(model.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        model.decoder = decoder_class(model.decoder, onnx=is_onnx)
        
        from funasr.utils.torch_function import sequence_mask
        model.make_pad_mask = sequence_mask(kwargs['max_seq_len'], flip=False)
    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
        
        model.forward = types.MethodType(export_forward, model)
        model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
@@ -54,6 +54,7 @@
    return decoder_out, pre_token_length
def export_dummy_inputs(self):
    speech = torch.randn(2, 30, 560)
    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
@@ -61,25 +62,24 @@
def export_input_names(self):
    """Return the names of the exported model's input tensors, in order.

    NOTE(review): presumably passed as ``input_names`` to the ONNX exporter;
    confirm against the caller.

    Returns:
        list[str]: ``["speech", "speech_lengths"]``.
    """
    return ["speech", "speech_lengths"]
def export_output_names(self):
    """Return the names of the exported model's output tensors, in order.

    NOTE(review): presumably passed as ``output_names`` to the ONNX exporter;
    confirm against the caller.

    Returns:
        list[str]: ``["logits", "token_num"]``.
    """
    return ["logits", "token_num"]
def export_dynamic_axes(self):
    """Return the variable-sized axes of the named tensors.

    The result maps a tensor name to a ``{axis index: symbolic axis name}``
    dict. Axis 0 of every listed tensor is labelled ``"batch_size"``; axis 1
    of ``"speech"`` and ``"logits"`` is labelled ``"feats_length"`` and
    ``"logits_length"`` respectively.

    NOTE(review): the layout matches the ``dynamic_axes`` argument of
    ``torch.onnx.export``; confirm that is how the caller consumes it.

    Returns:
        dict[str, dict[int, str]]: dynamic-axis specification per tensor name.
    """
    return {
        "speech": {0: "batch_size", 1: "feats_length"},
        "speech_lengths": {0: "batch_size"},
        "logits": {0: "batch_size", 1: "logits_length"},
    }
def export_name(self):
    """Return the file name used for the exported model.

    Returns:
        str: the constant ``"model.onnx"``.
    """
    return "model.onnx"