From aa3fe1a353bde71d106755d030d9e5300fbde328 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 22 Jul 2024 19:02:15 +0800
Subject: [PATCH] python runtime
---
funasr/models/mossformer/mossformer.py | 206 +++++++++++++++++++++++++++------------------------
1 file changed, 108 insertions(+), 98 deletions(-)
diff --git a/funasr/models/mossformer/mossformer.py b/funasr/models/mossformer/mossformer.py
index f1e8e28..30aab48 100644
--- a/funasr/models/mossformer/mossformer.py
+++ b/funasr/models/mossformer/mossformer.py
@@ -7,16 +7,20 @@
def identity(t, *args, **kwargs):
return t
+
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1,) * num_dims))
+
def exists(val):
return val is not None
+
def default(val, d):
return val if exists(val) else d
+
def padding_to_multiple_of(n, mult):
remainder = n % mult
@@ -26,7 +30,8 @@
class Transpose(nn.Module):
- """ Wrapper class of torch.transpose() for Sequential module. """
+ """Wrapper class of torch.transpose() for Sequential module."""
+
def __init__(self, shape: tuple):
super(Transpose, self).__init__()
self.shape = shape
@@ -51,17 +56,20 @@
Returns: outputs
- **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
"""
+
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: int = 1,
- padding: int = 0,
- bias: bool = False,
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = False,
) -> None:
super(DepthwiseConv1d, self).__init__()
- assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
+ assert (
+ out_channels % in_channels == 0
+ ), "out_channels should be constant multiple of in_channels"
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
@@ -90,12 +98,13 @@
Outputs: outputs
outputs (batch, time, dim): Tensor produced by the conformer convolution module.
"""
+
def __init__(
- self,
- in_channels: int,
- kernel_size: int = 17,
- expansion_factor: int = 2,
- dropout_p: float = 0.1,
+ self,
+ in_channels: int,
+ kernel_size: int = 17,
+ expansion_factor: int = 2,
+ dropout_p: float = 0.1,
) -> None:
super(ConvModule, self).__init__()
assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
@@ -103,7 +112,9 @@
self.sequential = nn.Sequential(
Transpose(shape=(1, 2)),
- DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+ DepthwiseConv1d(
+ in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2
+ ),
)
def forward(self, inputs):
@@ -111,33 +122,28 @@
class OffsetScale(nn.Module):
- def __init__(self, dim, heads = 1):
+ def __init__(self, dim, heads=1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
- nn.init.normal_(self.gamma, std = 0.02)
+ nn.init.normal_(self.gamma, std=0.02)
def forward(self, x):
- out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
- return out.unbind(dim = -2)
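+        # per-head affine: scale by gamma, shift by beta, then split into one tensor per head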
+ out = einsum("... d, h d -> ... h d", x, self.gamma) + self.beta
+ return out.unbind(dim=-2)
class FFConvM(nn.Module):
- def __init__(
- self,
- dim_in,
- dim_out,
- norm_klass = nn.LayerNorm,
- dropout = 0.1
- ):
+ def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in),
nn.Linear(dim_in, dim_out),
nn.SiLU(),
ConvModule(dim_out),
- nn.Dropout(dropout)
+ nn.Dropout(dropout),
)
+
def forward(
self,
x,
@@ -151,17 +157,17 @@
self,
*,
dim,
- group_size = 256,
- query_key_dim = 128,
- expansion_factor = 1.,
- causal = False,
- dropout = 0.1,
- rotary_pos_emb = None,
- norm_klass = nn.LayerNorm,
- shift_tokens = True
+ group_size=256,
+ query_key_dim=128,
+ expansion_factor=1.0,
+ causal=False,
+ dropout=0.1,
+ rotary_pos_emb=None,
+ norm_klass=nn.LayerNorm,
+        shift_tokens=True,
):
super().__init__()
- hidden_dim = int(dim * expansion_factor)
+ hidden_dim = int(dim * expansion_factor)
self.group_size = group_size
self.causal = causal
self.shift_tokens = shift_tokens
@@ -171,38 +177,32 @@
# norm
self.dropout = nn.Dropout(dropout)
# projections
-
- self.to_hidden = FFConvM(
- dim_in = dim,
- dim_out = hidden_dim,
- norm_klass = norm_klass,
- dropout = dropout,
- )
- self.to_qk = FFConvM(
- dim_in = dim,
- dim_out = query_key_dim,
- norm_klass = norm_klass,
- dropout = dropout,
- )
- self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+ self.to_hidden = FFConvM(
+ dim_in=dim,
+ dim_out=hidden_dim,
+ norm_klass=norm_klass,
+ dropout=dropout,
+ )
+ self.to_qk = FFConvM(
+ dim_in=dim,
+ dim_out=query_key_dim,
+ norm_klass=norm_klass,
+ dropout=dropout,
+ )
+
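+        # one shared qk projection, offset-scaled into four views: quad_q, lin_q, quad_k, lin_k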
+ self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
self.to_out = FFConvM(
- dim_in = dim*2,
- dim_out = dim,
- norm_klass = norm_klass,
- dropout = dropout,
- )
-
- self.gateActivate=nn.Sigmoid()
+ dim_in=dim * 2,
+ dim_out=dim,
+ norm_klass=norm_klass,
+ dropout=dropout,
+ )
- def forward(
- self,
- x,
- *,
- mask = None
- ):
+        self.gateActivate = nn.Sigmoid()
+
+    def forward(self, x, *, mask=None):
"""
b - batch
n - sequence length (within groups)
@@ -213,95 +213,105 @@
j - sequence dimension (target)
"""
- normed_x = x
+ normed_x = x
# do token shift - a great, costless trick from an independent AI researcher in Shenzhen
residual = x
if self.shift_tokens:
- x_shift, x_pass = normed_x.chunk(2, dim = -1)
- x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
- normed_x = torch.cat((x_shift, x_pass), dim = -1)
+ x_shift, x_pass = normed_x.chunk(2, dim=-1)
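+            # pad one zero step at the front of the time axis and drop the last step,
+            # shifting half of the channels forward by one position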
+ x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
+ normed_x = torch.cat((x_shift, x_pass), dim=-1)
# initial projections
- v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
+ v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
qk = self.to_qk(normed_x)
# offset and scale
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
- out = (att_u*v ) * self.gateActivate(att_v*u)
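+        # combine branches: the u-attention output scales v, gated by a sigmoid of
+        # the v-attention output times u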
+ out = (att_u * v) * self.gateActivate(att_v * u)
x = x + self.to_out(out)
return x
- def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
+ def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
if exists(mask):
- lin_mask = rearrange(mask, '... -> ... 1')
- lin_k = lin_k.masked_fill(~lin_mask, 0.)
+ lin_mask = rearrange(mask, "... -> ... 1")
+ lin_k = lin_k.masked_fill(~lin_mask, 0.0)
# rotate queries and keys
if exists(self.rotary_pos_emb):
- quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+ quad_q, lin_q, quad_k, lin_k = map(
+ self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k)
+ )
# padding for groups
padding = padding_to_multiple_of(n, g)
if padding > 0:
- quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
+ quad_q, quad_k, lin_q, lin_k, v, u = map(
+ lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
+ (quad_q, quad_k, lin_q, lin_k, v, u),
+ )
- mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
- mask = F.pad(mask, (0, padding), value = False)
+ mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
+ mask = F.pad(mask, (0, padding), value=False)
# group along sequence
- quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
+ quad_q, quad_k, lin_q, lin_k, v, u = map(
+ lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
+ (quad_q, quad_k, lin_q, lin_k, v, u),
+ )
if exists(mask):
- mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+ mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)
# calculate quadratic attention output
- sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
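+        # within-group similarity scores, normalized by group size g; the squared
+        # ReLU below is used in place of softmax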
+ sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g
attn = F.relu(sim) ** 2
attn = self.dropout(attn)
if exists(mask):
- attn = attn.masked_fill(~mask, 0.)
+ attn = attn.masked_fill(~mask, 0.0)
if self.causal:
- causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
- attn = attn.masked_fill(causal_mask, 0.)
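+            # strict upper-triangular mask zeroes attention to future positions within each group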
+ causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
+ attn = attn.masked_fill(causal_mask, 0.0)
- quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
- quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
+ quad_out_v = einsum("... i j, ... j d -> ... i d", attn, v)
+ quad_out_u = einsum("... i j, ... j d -> ... i d", attn, u)
# calculate linear attention output
if self.causal:
- lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+ lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
# exclusive cumulative sum along group dimension
- lin_kv = lin_kv.cumsum(dim = 1)
- lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
- lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+ lin_kv = lin_kv.cumsum(dim=1)
+ lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
+ lin_out_v = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)
- lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
+ lin_ku = einsum("b g n d, b g n e -> b g d e", lin_k, u) / g
# exclusive cumulative sum along group dimension
- lin_ku = lin_ku.cumsum(dim = 1)
- lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
- lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
+ lin_ku = lin_ku.cumsum(dim=1)
+ lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.0)
+ lin_out_u = einsum("b g d e, b g n d -> b g n e", lin_ku, lin_q)
else:
- lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
- lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
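+            # non-causal: a single global key-value summary over all positions, shared by every group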
+ lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
+ lin_out_v = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)
- lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
- lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
+ lin_ku = einsum("b g n d, b g n e -> b d e", lin_k, u) / n
+ lin_out_u = einsum("b g n d, b d e -> b g n e", lin_q, lin_ku)
# fold back groups into full sequence, and excise out padding
- return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
-
+ return map(
+ lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n],
+ (quad_out_v + lin_out_v, quad_out_u + lin_out_u),
+ )
--
Gitblit v1.9.1