nichongjia-2007
2023-07-07 e5151e047479e3414ed2faa2890bc3e7e17259be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
 
import torch
import torch.nn as nn
 
class SequentialRNNLM(nn.Module):
    """ONNX-export wrapper around a trained sequential RNN language model.

    Borrows the submodules of an existing model (embedding encoder, RNN
    stack, linear decoder) and exposes the helpers an ONNX exporter needs:
    dummy inputs, input/output names, dynamic axes, and a config dict.

    NOTE(review): get_dummy_inputs pairs an (1, 2) token tensor with a
    hidden state whose batch dim is 1, which implies the wrapped RNN was
    built with batch_first=True — confirm against the source model.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        # Reuse the trained submodules directly; **kwargs is accepted for
        # interface compatibility and intentionally unused.
        self.encoder = model.encoder
        self.rnn = model.rnn
        self.rnn_type = model.rnn_type
        self.decoder = model.decoder
        self.nlayers = model.nlayers
        self.nhid = model.nhid
        self.model_name = "seq_rnnlm"

    def forward(self, y, hidden1, hidden2=None):
        # batch_score function.
        embedded = self.encoder(y)
        is_lstm = self.rnn_type == "LSTM"
        if is_lstm:
            # LSTM state is a (h, c) pair; hidden2 carries the cell state.
            rnn_out, (hidden1, hidden2) = self.rnn(embedded, (hidden1, hidden2))
        else:
            rnn_out, hidden1 = self.rnn(embedded, hidden1)

        # Flatten (batch, length) for the linear decoder, then restore.
        batch, length, feat = rnn_out.size(0), rnn_out.size(1), rnn_out.size(2)
        logits = self.decoder(rnn_out.contiguous().view(batch * length, feat))
        logits = logits.view(batch, length, logits.size(1))
        if is_lstm:
            return logits, hidden1, hidden2
        return logits, hidden1

    def get_dummy_inputs(self):
        """Return example inputs for tracing/export (batch size 1)."""
        tokens = torch.LongTensor([0, 1]).unsqueeze(0)
        state = torch.randn(self.nlayers, 1, self.nhid)
        # LSTM needs both hidden and cell state placeholders.
        return (tokens, state, state) if self.rnn_type == "LSTM" else (tokens, state)

    def get_input_names(self):
        """ONNX graph input names, ordered to match get_dummy_inputs."""
        names = ["x", "in_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("in_hidden2")
        return names

    def get_output_names(self):
        """ONNX graph output names, ordered to match forward's returns."""
        names = ["y", "out_hidden1"]
        if self.rnn_type == "LSTM":
            names.append("out_hidden2")
        return names

    def get_dynamic_axes(self):
        """Dynamic-axis spec for torch.onnx.export (variable batch/length)."""
        axes = {
            "x": {0: "x_batch", 1: "x_length"},
            "y": {0: "y_batch"},
            "in_hidden1": {1: "hidden1_batch"},
            "out_hidden1": {1: "out_hidden1_batch"},
        }
        if self.rnn_type == "LSTM":
            axes["in_hidden2"] = {1: "hidden2_batch"}
            axes["out_hidden2"] = {1: "out_hidden2_batch"}
        return axes

    def get_model_config(self, path):
        """Config dict describing the exported model file under *path*."""
        return {
            "use_lm": True,
            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
            "lm_type": "SequentialRNNLM",
            "rnn_type": self.rnn_type,
            "nhid": self.nhid,
            "nlayers": self.nlayers,
        }