游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/ct_transformer_streaming/export_meta.py
@@ -9,25 +9,28 @@
def export_rebuild_model(model, **kwargs):
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    return model
def export_forward(self, inputs: torch.Tensor,
            text_lengths: torch.Tensor,
            vad_indexes: torch.Tensor,
            sub_masks: torch.Tensor,
            ):
def export_forward(
    self,
    inputs: torch.Tensor,
    text_lengths: torch.Tensor,
    vad_indexes: torch.Tensor,
    sub_masks: torch.Tensor,
):
    """Compute loss value from buffer sequences.
    Args:
@@ -41,6 +44,7 @@
    y = self.decoder(h)
    return y
def export_dummy_inputs(self):
    length = 120
    text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
@@ -50,28 +54,23 @@
    sub_masks = torch.tril(sub_masks).type(torch.float32)
    return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
def export_input_names(self):
    return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
    return ["inputs", "text_lengths", "vad_masks", "sub_masks"]
def export_output_names(self):
    return ['logits']
    return ["logits"]
def export_dynamic_axes(self):
    return {
        'inputs': {
            1: 'feats_length'
        },
        'vad_masks': {
            2: 'feats_length1',
            3: 'feats_length2'
        },
        'sub_masks': {
            2: 'feats_length1',
            3: 'feats_length2'
        },
        'logits': {
            1: 'logits_length'
        },
        "inputs": {1: "feats_length"},
        "vad_masks": {2: "feats_length1", 3: "feats_length2"},
        "sub_masks": {2: "feats_length1", 3: "feats_length2"},
        "logits": {1: "logits_length"},
    }
def export_name(self):
    return "model.onnx"