def identity(t, *args, **kwargs):
    return t


def append_dims(x, num_dims):
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d

def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder

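# Example with the helper above: padding_to_multiple_of(1000, 256) == 24, so a
# 1000-step sequence is padded out to 1024 and split into 4 groups of 256.
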
class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for Sequential module."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)

class DepthwiseConv1d(nn.Module):
    """
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,  # one filter group per input channel makes the convolution depthwise
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs):
        return self.conv(inputs)

class ConvModule(nn.Module):
    """
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by conformer convolution module.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2
            ),
        )

    def forward(self, inputs):
        # depthwise convolution over time, transposed back to (batch, time, dim), with a residual add
        return inputs + self.sequential(inputs).transpose(1, 2)

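# Shape sketch for the convolution path above (illustrative values):
#   x = torch.randn(2, 100, 64)       # (batch, time, dim)
#   ConvModule(64)(x).shape           # -> torch.Size([2, 100, 64])
# Transpose(shape=(1, 2)) is needed because nn.Conv1d expects (batch, channels, time).
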
class OffsetScale(nn.Module):
    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        out = einsum("... d, h d -> ... h d", x, self.gamma) + self.beta
        return out.unbind(dim=-2)

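# Usage sketch (illustrative): with heads=4, OffsetScale produces four differently
# scaled and offset views of one projection, unpacked below as quad_q, lin_q, quad_k, lin_k.
#   scale = OffsetScale(dim=128, heads=4)
#   quad_q, lin_q, quad_k, lin_k = scale(torch.randn(2, 100, 128))  # each (2, 100, 128)
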
class FFConvM(nn.Module):
    def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        output = self.mdl(x)
        return output

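# FFConvM sketch (illustrative): norm -> linear -> SiLU -> ConvModule -> dropout, mapping
# (batch, time, dim_in) to (batch, time, dim_out).
#   ff = FFConvM(dim_in=64, dim_out=128)
#   ff(torch.randn(2, 100, 64)).shape  # -> torch.Size([2, 100, 128])
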
    def __init__(
        self,
        *,
        dim,
        group_size=256,
        query_key_dim=128,
        expansion_factor=1.0,
        causal=False,
        dropout=0.1,
        rotary_pos_emb=None,
        norm_klass=nn.LayerNorm,
        shift_tokens=True,
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens
        # rotary positional embedding, applied to queries/keys in cal_attention below
        self.rotary_pos_emb = rotary_pos_emb

        # norm
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in=dim,
            dim_out=hidden_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.to_qk = FFConvM(
            dim_in=dim,
            dim_out=query_key_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)

        self.to_out = FFConvM(
            dim_in=dim * 2,
            dim_out=dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        self.gateActivate = nn.Sigmoid()

    def forward(self, x, *, mask=None):
| | | """ |
| | | b - batch |
| | | n - sequence length (within groups) |
| | |
| | | j - sequence dimension (target) |
| | | """ |
| | | |
| | | normed_x = x |
| | | normed_x = x |
| | | |
| | | # do token shift - a great, costless trick from an independent AI researcher in Shenzhen |
| | | residual = x |
| | | |
| | | if self.shift_tokens: |
| | | x_shift, x_pass = normed_x.chunk(2, dim = -1) |
| | | x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.) |
| | | normed_x = torch.cat((x_shift, x_pass), dim = -1) |
| | | x_shift, x_pass = normed_x.chunk(2, dim=-1) |
| | | x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0) |
| | | normed_x = torch.cat((x_shift, x_pass), dim=-1) |

        # initial projections

        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)

        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
        out = (att_u * v) * self.gateActivate(att_v * u)
        x = x + self.to_out(out)
        return x
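
    # cal_attention returns one pair of attended tensors computed with shared attention
    # weights: att_u (the attended u) multiplies v, and sigmoid(att_v * u) gates that
    # product before to_out projects it back to the model dimension for the residual add.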
    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        if exists(mask):
            lin_mask = rearrange(mask, "... -> ... 1")
            lin_k = lin_k.masked_fill(~lin_mask, 0.0)

        # rotate queries and keys

        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(
                self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k)
            )

        # padding for groups

        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(
                lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
                (quad_q, quad_k, lin_q, lin_k, v, u),
            )

            mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
            mask = F.pad(mask, (0, padding), value=False)

        # group along sequence

        quad_q, quad_k, lin_q, lin_k, v, u = map(
            lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
            (quad_q, quad_k, lin_q, lin_k, v, u),
        )

        if exists(mask):
            mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)

        # calculate quadratic attention output

        sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g

        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.0)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.0)

        quad_out_v = einsum("... i j, ... j d -> ... i d", attn, v)
        quad_out_u = einsum("... i j, ... j d -> ... i d", attn, u)

        # calculate linear attention output

        if self.causal:
            lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim=1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_v = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)

            lin_ku = einsum("b g n d, b g n e -> b g d e", lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim=1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_u = einsum("b g d e, b g n d -> b g n e", lin_ku, lin_q)
        else:
            lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
            lin_out_v = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)

            lin_ku = einsum("b g n d, b g n e -> b d e", lin_k, u) / n
            lin_out_u = einsum("b g n d, b d e -> b g n e", lin_q, lin_ku)

        # fold back groups into full sequence, and excise out padding
        return map(
            lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n],
            (quad_out_v + lin_out_v, quad_out_u + lin_out_u),
        )
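
# How the attention above scales: the sequence is zero-padded to a multiple of group_size
# and reshaped to (batch, groups, group_size, dim). The quadratic term (relu-squared
# similarities) attends only within each group, while the linear terms compress each group's
# keys/values into small d-by-e state matrices shared across groups (via an exclusive
# cumulative sum over groups in the causal case), so cost grows roughly linearly with
# sequence length instead of quadratically.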