liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/ct_transformer/export_meta.py
@@ -9,18 +9,18 @@
def export_rebuild_model(model, **kwargs):
    """Prepare ``model`` for export by swapping its encoder and binding export hooks.

    The encoder is replaced by an instance of the class registered in
    ``tables.encoder_classes`` under ``kwargs["encoder"] + "Export"``, built
    from the current encoder. The module-level ``export_*`` functions are then
    bound onto the model instance as methods.

    Args:
        model: The model object to modify in place.
        **kwargs: Must contain ``"encoder"`` (the registered encoder name).
            May contain ``"type"``; the wrapper is created with ``onnx=True``
            only when it equals ``"onnx"`` (which is also the default).

    Returns:
        The same ``model`` object, after modification.
    """
    export_type = kwargs.get("type", "onnx")
    wrapper_cls = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = wrapper_cls(model.encoder, onnx=(export_type == "onnx"))

    # (attribute name on the model, function to bind as that method)
    bindings = (
        ("forward", export_forward),
        ("export_dummy_inputs", export_dummy_inputs),
        ("export_input_names", export_input_names),
        ("export_output_names", export_output_names),
        ("export_dynamic_axes", export_dynamic_axes),
        ("export_name", export_name),
    )
    for attr_name, func in bindings:
        setattr(model, attr_name, types.MethodType(func, model))

    return model
@@ -37,31 +37,31 @@
    y = self.decoder(h)
    return y
def export_dummy_inputs(self):
    """Build a sample ``(text_indexes, text_lengths)`` input pair for export.

    Returns:
        A tuple of two ``torch.int32`` tensors:

        * ``text_indexes`` -- shape ``(2, 120)``, random integers drawn from
          ``[0, self.embed.num_embeddings)``.
        * ``text_lengths`` -- ``[100, 120]``, one length per batch row.
    """
    length = 120
    text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
    # The two rows get different lengths (100 and 120).
    text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
    return (text_indexes, text_lengths)
def export_input_names(self):
    """Return the names of the exported model's inputs, in positional order."""
    return ["inputs", "text_lengths"]
def export_output_names(self):
    """Return the names of the exported model's outputs."""
    return ["logits"]
def export_dynamic_axes(self):
    """Return the per-tensor mapping of axis index to symbolic axis name.

    Keys are the input/output names used by ``export_input_names`` and
    ``export_output_names``; each value maps an axis index to a label.
    NOTE(review): presumably passed as ``dynamic_axes`` to the ONNX exporter
    so these axes stay variable-sized -- confirm against the caller.
    """
    return {
        "inputs": {0: "batch_size", 1: "feats_length"},
        "text_lengths": {
            0: "batch_size",
        },
        "logits": {0: "batch_size", 1: "logits_length"},
    }
def export_name(self):
    """Return the file name used for the exported model."""
    return "model.onnx"