import torch
import torch.nn as nn
 
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
#from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.punctuation.abs_model import AbsPunctuation
 
 
class TargetDelayTransformer(nn.Module):
    """Export wrapper that runs a punctuation model's embedding, SANM encoder, and decoder."""

    def __init__(
            self,
            model,
            max_seq_len=512,
            model_name='punc_model',
            **kwargs,
    ):
        super().__init__()
        onnx = kwargs.get("onnx", False)
        self.embed = model.embed
        self.decoder = model.decoder
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            raise ValueError("Only SANMEncoder is supported for export.")
 
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Compute punctuation logits for a batch of token id sequences.

        Args:
            input (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)

        Returns:
            torch.Tensor: Punctuation logits. (batch, len, num_classes)
        """
        x = self.embed(input)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y
 
    def get_dummy_inputs(self):
        # Random token ids for a batch of two sequences, one shorter than the
        # other, used only to trace the graph during export.
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)
 
    def get_input_names(self):
        return ['input', 'text_lengths']
 
    def get_output_names(self):
        return ['logits']
 
    def get_dynamic_axes(self):
        return {
            'input': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
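
# Usage sketch (illustrative, not part of the original file): given a loaded FunASR
# punctuation model `punc_model` that exposes `embed`, `encoder`, and `decoder`,
# the helper methods above are meant to feed torch.onnx.export, roughly as follows:
#
#     wrapper = TargetDelayTransformer(punc_model, onnx=True)
#     dummy_inputs = wrapper.get_dummy_inputs()
#     torch.onnx.export(
#         wrapper,
#         dummy_inputs,
#         "punc_model.onnx",
#         input_names=wrapper.get_input_names(),
#         output_names=wrapper.get_output_names(),
#         dynamic_axes=wrapper.get_dynamic_axes(),
#         opset_version=13,
#     )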