游雁
2023-03-31 d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
 
import torch
from torch import nn
 
 
class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size
 
    def forward(self, x, mask):
 
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, mask)
        if self.in_size == self.size:
            x = x + residual
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = x + residual
 
        return x, mask
 
 
class EncoderLayerConformer(nn.Module):
    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.feed_forward_macaron = model.feed_forward_macaron
        self.conv_module = model.conv_module
        self.norm_ff = model.norm_ff
        self.norm_mha = model.norm_mha
        self.norm_ff_macaron = model.norm_ff_macaron
        self.norm_conv = model.norm_conv
        self.norm_final = model.norm_final
        self.size = model.size
 
    def forward(self, x, mask):
        if isinstance(x, tuple):
            x, pos_emb = x[0], x[1]
        else:
            x, pos_emb = x, None
 
        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.feed_forward_macaron(x) * 0.5
 
        residual = x
        x = self.norm_mha(x)
 
        x_q = x
 
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        x = residual + x_att
 
        if self.conv_module is not None:
            residual = x
            x = self.norm_conv(x)
            x = residual +  self.conv_module(x)
 
        residual = x
        x = self.norm_ff(x)
        x = residual + self.feed_forward(x) * 0.5
 
        x = self.norm_final(x)
 
        if pos_emb is not None:
            return (x, pos_emb), mask
 
        return x, mask