From aa3fe1a353bde71d106755d030d9e5300fbde328 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 22 Jul 2024 19:02:15 +0800
Subject: [PATCH] python runtime

---
 funasr/models/mossformer/mossformer.py |  206 +++++++++++++++++++++++++++------------------------
 1 file changed, 108 insertions(+), 98 deletions(-)

diff --git a/funasr/models/mossformer/mossformer.py b/funasr/models/mossformer/mossformer.py
index f1e8e28..30aab48 100644
--- a/funasr/models/mossformer/mossformer.py
+++ b/funasr/models/mossformer/mossformer.py
@@ -7,16 +7,20 @@
 def identity(t, *args, **kwargs):
     return t
 
+
 def append_dims(x, num_dims):
     if num_dims <= 0:
         return x
     return x.view(*x.shape, *((1,) * num_dims))
 
+
 def exists(val):
     return val is not None
 
+
 def default(val, d):
     return val if exists(val) else d
+
 
 def padding_to_multiple_of(n, mult):
     remainder = n % mult
@@ -26,7 +30,8 @@
 
 
 class Transpose(nn.Module):
-    """ Wrapper class of torch.transpose() for Sequential module. """
+    """Wrapper class of torch.transpose() for Sequential module."""
+
     def __init__(self, shape: tuple):
         super(Transpose, self).__init__()
         self.shape = shape
@@ -51,17 +56,20 @@
     Returns: outputs
         - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
     """
+
     def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: int,
-            stride: int = 1,
-            padding: int = 0,
-            bias: bool = False,
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        bias: bool = False,
     ) -> None:
         super(DepthwiseConv1d, self).__init__()
-        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
+        assert (
+            out_channels % in_channels == 0
+        ), "out_channels should be constant multiple of in_channels"
         self.conv = nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -90,12 +98,13 @@
     Outputs: outputs
         outputs (batch, time, dim): Tensor produces by conformer convolution module.
     """
+
     def __init__(
-            self,
-            in_channels: int,
-            kernel_size: int = 17,
-            expansion_factor: int = 2,
-            dropout_p: float = 0.1,
+        self,
+        in_channels: int,
+        kernel_size: int = 17,
+        expansion_factor: int = 2,
+        dropout_p: float = 0.1,
     ) -> None:
         super(ConvModule, self).__init__()
         assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
@@ -103,7 +112,9 @@
 
         self.sequential = nn.Sequential(
             Transpose(shape=(1, 2)),
-            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+            DepthwiseConv1d(
+                in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2
+            ),
         )
 
     def forward(self, inputs):
@@ -111,33 +122,28 @@
 
 
 class OffsetScale(nn.Module):
-    def __init__(self, dim, heads = 1):
+    def __init__(self, dim, heads=1):
         super().__init__()
         self.gamma = nn.Parameter(torch.ones(heads, dim))
         self.beta = nn.Parameter(torch.zeros(heads, dim))
-        nn.init.normal_(self.gamma, std = 0.02)
+        nn.init.normal_(self.gamma, std=0.02)
 
     def forward(self, x):
-        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
-        return out.unbind(dim = -2)
+        out = einsum("... d, h d -> ... h d", x, self.gamma) + self.beta
+        return out.unbind(dim=-2)
 
 
 class FFConvM(nn.Module):
-    def __init__(
-        self,
-        dim_in,
-        dim_out,
-        norm_klass = nn.LayerNorm,
-        dropout = 0.1
-    ):
+    def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
         super().__init__()
         self.mdl = nn.Sequential(
             norm_klass(dim_in),
             nn.Linear(dim_in, dim_out),
             nn.SiLU(),
             ConvModule(dim_out),
-            nn.Dropout(dropout)
+            nn.Dropout(dropout),
         )
+
     def forward(
         self,
         x,
@@ -151,17 +157,17 @@
         self,
         *,
         dim,
-        group_size = 256,
-        query_key_dim = 128,
-        expansion_factor = 1.,
-        causal = False,
-        dropout = 0.1,
-        rotary_pos_emb = None,
-        norm_klass = nn.LayerNorm,
-        shift_tokens = True
+        group_size=256,
+        query_key_dim=128,
+        expansion_factor=1.0,
+        causal=False,
+        dropout=0.1,
+        rotary_pos_emb=None,
+        norm_klass=nn.LayerNorm,
+        shift_tokens=True
     ):
         super().__init__()
-        hidden_dim = int(dim * expansion_factor)        
+        hidden_dim = int(dim * expansion_factor)
         self.group_size = group_size
         self.causal = causal
         self.shift_tokens = shift_tokens
@@ -171,38 +177,32 @@
         # norm
         self.dropout = nn.Dropout(dropout)
         # projections
-        
-        self.to_hidden = FFConvM(
-            dim_in = dim,
-            dim_out = hidden_dim,
-            norm_klass = norm_klass,
-            dropout = dropout,
-            )
-        self.to_qk = FFConvM(
-            dim_in = dim,
-            dim_out = query_key_dim,
-            norm_klass = norm_klass,
-            dropout = dropout,
-            )
 
-        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+        self.to_hidden = FFConvM(
+            dim_in=dim,
+            dim_out=hidden_dim,
+            norm_klass=norm_klass,
+            dropout=dropout,
+        )
+        self.to_qk = FFConvM(
+            dim_in=dim,
+            dim_out=query_key_dim,
+            norm_klass=norm_klass,
+            dropout=dropout,
+        )
+
+        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
 
         self.to_out = FFConvM(
-            dim_in = dim*2,
-            dim_out = dim,
-            norm_klass = norm_klass,
-            dropout = dropout,
-            )
-        
-        self.gateActivate=nn.Sigmoid() 
+            dim_in=dim * 2,
+            dim_out=dim,
+            norm_klass=norm_klass,
+            dropout=dropout,
+        )
 
-    def forward(
-        self,
-        x,
-        *,
-        mask = None
-    ):
+        self.gateActivate = nn.Sigmoid()
 
+    def forward(self, x, *, mask=None):
         """
         b - batch
         n - sequence length (within groups)
@@ -213,95 +213,105 @@
         j - sequence dimension (target)
         """
 
-        normed_x = x 
+        normed_x = x
 
         # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
         residual = x
 
         if self.shift_tokens:
-            x_shift, x_pass = normed_x.chunk(2, dim = -1)
-            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
-            normed_x = torch.cat((x_shift, x_pass), dim = -1)
+            x_shift, x_pass = normed_x.chunk(2, dim=-1)
+            x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
+            normed_x = torch.cat((x_shift, x_pass), dim=-1)
 
         # initial projections
 
-        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
+        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
         qk = self.to_qk(normed_x)
 
         # offset and scale
         quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
         att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
-        out = (att_u*v ) * self.gateActivate(att_v*u)        
+        out = (att_u * v) * self.gateActivate(att_v * u)
         x = x + self.to_out(out)
         return x
 
-    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
+    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
         b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
 
         if exists(mask):
-            lin_mask = rearrange(mask, '... -> ... 1')
-            lin_k = lin_k.masked_fill(~lin_mask, 0.)
+            lin_mask = rearrange(mask, "... -> ... 1")
+            lin_k = lin_k.masked_fill(~lin_mask, 0.0)
 
         # rotate queries and keys
 
         if exists(self.rotary_pos_emb):
-            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+            quad_q, lin_q, quad_k, lin_k = map(
+                self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k)
+            )
 
         # padding for groups
 
         padding = padding_to_multiple_of(n, g)
 
         if padding > 0:
-            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
+            quad_q, quad_k, lin_q, lin_k, v, u = map(
+                lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
+                (quad_q, quad_k, lin_q, lin_k, v, u),
+            )
 
-            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
-            mask = F.pad(mask, (0, padding), value = False)
+            mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
+            mask = F.pad(mask, (0, padding), value=False)
 
         # group along sequence
 
-        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
+        quad_q, quad_k, lin_q, lin_k, v, u = map(
+            lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
+            (quad_q, quad_k, lin_q, lin_k, v, u),
+        )
 
         if exists(mask):
-            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+            mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)
 
         # calculate quadratic attention output
 
-        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+        sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g
 
         attn = F.relu(sim) ** 2
         attn = self.dropout(attn)
 
         if exists(mask):
-            attn = attn.masked_fill(~mask, 0.)
+            attn = attn.masked_fill(~mask, 0.0)
 
         if self.causal:
-            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
-            attn = attn.masked_fill(causal_mask, 0.)
+            causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
+            attn = attn.masked_fill(causal_mask, 0.0)
 
-        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
-        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
+        quad_out_v = einsum("... i j, ... j d -> ... i d", attn, v)
+        quad_out_u = einsum("... i j, ... j d -> ... i d", attn, u)
 
         # calculate linear attention output
 
         if self.causal:
-            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+            lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
             # exclusive cumulative sum along group dimension
-            lin_kv = lin_kv.cumsum(dim = 1)
-            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
-            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+            lin_kv = lin_kv.cumsum(dim=1)
+            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
+            lin_out_v = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)
 
-            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
+            lin_ku = einsum("b g n d, b g n e -> b g d e", lin_k, u) / g
             # exclusive cumulative sum along group dimension
-            lin_ku = lin_ku.cumsum(dim = 1)
-            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
-            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
+            lin_ku = lin_ku.cumsum(dim=1)
+            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.0)
+            lin_out_u = einsum("b g d e, b g n d -> b g n e", lin_ku, lin_q)
         else:
-            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
-            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
+            lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
+            lin_out_v = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)
 
-            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
-            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
+            lin_ku = einsum("b g n d, b g n e -> b d e", lin_k, u) / n
+            lin_out_u = einsum("b g n d, b d e -> b g n e", lin_q, lin_ku)
 
         # fold back groups into full sequence, and excise out padding
-        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
-
+        return map(
+            lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n],
+            (quad_out_v + lin_out_v, quad_out_u + lin_out_u),
+        )
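
Reviewer note: the hunks above restyle the grouped gated-attention path (token shift, group padding and regrouping, quadratic plus linear attention, and the sigmoid-gated combination). For reference, below is a minimal self-contained sketch of that path. It assumes only torch and einops; the shapes, group size, and the random stand-ins for the FFConvM/OffsetScale projections are illustrative assumptions, not FunASR values, so it is not a drop-in replacement for cal_attention().

    import torch
    import torch.nn.functional as F
    from torch import einsum
    from einops import rearrange

    torch.manual_seed(0)
    b, n, dim, qk_dim, g = 2, 300, 64, 32, 256  # batch, length, model dim, q/k dim, group size

    x = torch.randn(b, n, dim)

    # token shift: move half of the channels one frame forward in time
    x_shift, x_pass = x.chunk(2, dim=-1)
    x = torch.cat((F.pad(x_shift, (0, 0, 1, -1), value=0.0), x_pass), dim=-1)

    # random stand-ins for the module's projections (FFConvM / OffsetScale)
    v, u = torch.randn(b, n, dim), torch.randn(b, n, dim)        # ~ to_hidden(x).chunk(2, dim=-1)
    quad_q = lin_q = quad_k = lin_k = torch.randn(b, n, qk_dim)  # ~ qk_offset_scale(to_qk(x))

    # pad so the sequence splits into whole groups of size g, then group along the sequence
    pad = (g - n % g) % g  # 212 for n=300, g=256
    grp = lambda t: rearrange(F.pad(t, (0, 0, 0, pad), value=0.0), "b (g n) d -> b g n d", n=g)
    qq, qk, lq, lk, vg, ug = map(grp, (quad_q, quad_k, lin_q, lin_k, v, u))

    # quadratic (within-group) attention with the relu**2 kernel
    attn = F.relu(einsum("... i d, ... j d -> ... i j", qq, qk) / g) ** 2
    quad_v = einsum("... i j, ... j d -> ... i d", attn, vg)
    quad_u = einsum("... i j, ... j d -> ... i d", attn, ug)

    # non-causal linear (global) attention shared across groups
    lin_v = einsum("b g n d, b d e -> b g n e", lq, einsum("b g n d, b g n e -> b d e", lk, vg) / n)
    lin_u = einsum("b g n d, b d e -> b g n e", lq, einsum("b g n d, b g n e -> b d e", lk, ug) / n)

    # fold groups back into the full sequence and cut off the padding
    ungrp = lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n]
    att_v, att_u = ungrp(quad_v + lin_v), ungrp(quad_u + lin_u)

    # sigmoid-gated combination of the two streams, as in forward() above
    out = (att_u * v) * torch.sigmoid(att_v * u)
    print(out.shape)  # torch.Size([2, 300, 64])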

--
Gitblit v1.9.1