游雁
2023-03-29 a030ff0f85fd6b1cc2a1d443d2fcfb11ccb1aa8f
funasr/export/models/target_delay_transformer.py
@@ -28,7 +28,7 @@
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        self.model = model
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
@@ -46,71 +46,71 @@
        from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
        from funasr.punctuation.abs_model import AbsPunctuation
        class TargetDelayTransformer(nn.Module):
            def __init__(
                    self,
                    model,
                    max_seq_len=512,
                    model_name='punc_model',
                    **kwargs,
            ):
                super().__init__()
                onnx = False
                if "onnx" in kwargs:
                    onnx = kwargs["onnx"]
                self.embed = model.embed
                self.decoder = model.decoder
                self.model = model
                self.feats_dim = self.embed.embedding_dim
                self.num_embeddings = self.embed.num_embeddings
                self.model_name = model_name
                if isinstance(model.encoder, SANMEncoder):
                    self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
                else:
                    assert False, "Only support samn encode."
            def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
                """Compute loss value from buffer sequences.
                Args:
                    input (torch.Tensor): Input ids. (batch, len)
                    hidden (torch.Tensor): Target ids. (batch, len)
                """
                x = self.embed(input)
                # mask = self._target_mask(input)
                h, _ = self.encoder(x, text_lengths)
                y = self.decoder(h)
                return y
            def get_dummy_inputs(self):
                length = 120
                text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
                text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
                return (text_indexes, text_lengths)
            def get_input_names(self):
                return ['input', 'text_lengths']
            def get_output_names(self):
                return ['logits']
            def get_dynamic_axes(self):
                return {
                    'input': {
                        0: 'batch_size',
                        1: 'feats_length'
                    },
                    'text_lengths': {
                        0: 'batch_size',
                    },
                    'logits': {
                        0: 'batch_size',
                        1: 'logits_length'
                    },
                }
        # class TargetDelayTransformer(nn.Module):
        #
        #     def __init__(
        #             self,
        #             model,
        #             max_seq_len=512,
        #             model_name='punc_model',
        #             **kwargs,
        #     ):
        #         super().__init__()
        #         onnx = False
        #         if "onnx" in kwargs:
        #             onnx = kwargs["onnx"]
        #         self.embed = model.embed
        #         self.decoder = model.decoder
        #         self.model = model
        #         self.feats_dim = self.embed.embedding_dim
        #         self.num_embeddings = self.embed.num_embeddings
        #         self.model_name = model_name
        #
        #         if isinstance(model.encoder, SANMEncoder):
        #             self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        #         else:
        #             assert False, "Only support samn encode."
        #
        #     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        #         """Compute loss value from buffer sequences.
        #
        #         Args:
        #             input (torch.Tensor): Input ids. (batch, len)
        #             hidden (torch.Tensor): Target ids. (batch, len)
        #
        #         """
        #         x = self.embed(input)
        #         # mask = self._target_mask(input)
        #         h, _ = self.encoder(x, text_lengths)
        #         y = self.decoder(h)
        #         return y
        #
        #     def get_dummy_inputs(self):
        #         length = 120
        #         text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        #         text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        #         return (text_indexes, text_lengths)
        #
        #     def get_input_names(self):
        #         return ['input', 'text_lengths']
        #
        #     def get_output_names(self):
        #         return ['logits']
        #
        #     def get_dynamic_axes(self):
        #         return {
        #             'input': {
        #                 0: 'batch_size',
        #                 1: 'feats_length'
        #             },
        #             'text_lengths': {
        #                 0: 'batch_size',
        #             },
        #             'logits': {
        #                 0: 'batch_size',
        #                 1: 'logits_length'
        #             },
        #         }
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)