From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/mossformer/mossformer_encoder.py |  123 ++++++++++++++++++++++-------------------
 1 file changed, 66 insertions(+), 57 deletions(-)

diff --git a/funasr/models/mossformer/mossformer_encoder.py b/funasr/models/mossformer/mossformer_encoder.py
index a28c960..888aef8 100644
--- a/funasr/models/mossformer/mossformer_encoder.py
+++ b/funasr/models/mossformer/mossformer_encoder.py
@@ -1,18 +1,20 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 try:
     from rotary_embedding_torch import RotaryEmbedding
 except:
-    print("If you want use mossformer, lease install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
+    print(
+        "If you want use mossformer, lease install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
+    )
 from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
 from funasr.models.transformer.embedding import ScaledSinuEmbedding
 from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM
 
 
 def select_norm(norm, dim, shape):
-    """Just a wrapper to select the normalization type.
-    """
+    """Just a wrapper to select the normalization type."""
 
     if norm == "gln":
         return GlobalLayerNorm(dim, shape, elementwise_affine=True)
@@ -30,37 +32,50 @@
         *,
         dim,
         depth,
-        group_size = 256, 
-        query_key_dim = 128, 
-        expansion_factor = 4.,
-        causal = False,
-        attn_dropout = 0.1,
-        norm_type = 'scalenorm',
-        shift_tokens = True
+        group_size=256,
+        query_key_dim=128,
+        expansion_factor=4.0,
+        causal=False,
+        attn_dropout=0.1,
+        norm_type="scalenorm",
+        shift_tokens=True
     ):
         super().__init__()
-        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+        assert norm_type in (
+            "scalenorm",
+            "layernorm",
+        ), "norm_type must be one of scalenorm or layernorm"
 
-        if norm_type == 'scalenorm':
+        if norm_type == "scalenorm":
             norm_klass = ScaleNorm
-        elif norm_type == 'layernorm':
+        elif norm_type == "layernorm":
             norm_klass = nn.LayerNorm
 
         self.group_size = group_size
 
-        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
         # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
-        self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
+        self.layers = nn.ModuleList(
+            [
+                FLASH_ShareA_FFConvM(
+                    dim=dim,
+                    group_size=group_size,
+                    query_key_dim=query_key_dim,
+                    expansion_factor=expansion_factor,
+                    causal=causal,
+                    dropout=attn_dropout,
+                    rotary_pos_emb=rotary_pos_emb,
+                    norm_klass=norm_klass,
+                    shift_tokens=shift_tokens,
+                )
+                for _ in range(depth)
+            ]
+        )
 
-    def forward(
-        self,
-        x,
-        *,
-        mask = None
-    ):
+    def forward(self, x, *, mask=None):
         ii = 0
         for flash in self.layers:
-            x = flash(x, mask = mask)
+            x = flash(x, mask=mask)
             ii = ii + 1
         return x
 
@@ -119,25 +134,19 @@
             self.pos_enc = ScaledSinuEmbedding(out_channels)
 
         self.mdl = Computation_Block(
-                    num_blocks,
-                    out_channels,
-                    norm,
-                    skip_around_intra=skip_around_intra,
-                )
-
-        self.conv1d_out = nn.Conv1d(
-            out_channels, out_channels * num_spks, kernel_size=1
+            num_blocks,
+            out_channels,
+            norm,
+            skip_around_intra=skip_around_intra,
         )
+
+        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
         self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
         self.prelu = nn.PReLU()
         self.activation = nn.ReLU()
         # gated output layer
-        self.output = nn.Sequential(
-            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
-        )
-        self.output_gate = nn.Sequential(
-            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
-        )
+        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
+        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
 
     def forward(self, x):
         """Returns the output tensor.
@@ -165,18 +174,17 @@
         # [B, N, L]
         x = self.conv1d_encoder(x)
         if self.use_global_pos_enc:
-            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
+            # x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
             #    x.size(1) ** 0.5)
             base = x
             x = x.transpose(1, -1)
             emb = self.pos_enc(x)
-            emb = emb.transpose(0, -1) 
-            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
+            emb = emb.transpose(0, -1)
+            # print('base: {}, emb: {}'.format(base.shape, emb.shape))
             x = base + emb
-            
 
         # [B, N, S]
-        #for i in range(self.num_modules):
+        # for i in range(self.num_modules):
         #    x = self.dual_mdl[i](x)
         x = self.mdl(x)
         x = self.prelu(x)
@@ -264,6 +272,7 @@
 
         return x
 
+
 class MossFormerM(nn.Module):
     """This class implements the transformer encoder.
 
@@ -293,27 +302,28 @@
     >>> output.shape
     torch.Size([8, 60, 512])
     """
+
     def __init__(
         self,
         num_blocks,
         d_model=None,
         causal=False,
-        group_size = 256,
-        query_key_dim = 128,
-        expansion_factor = 4.,
-        attn_dropout = 0.1
+        group_size=256,
+        query_key_dim=128,
+        expansion_factor=4.0,
+        attn_dropout=0.1,
     ):
         super().__init__()
 
         self.mossformerM = MossformerBlock(
-                           dim=d_model,
-                           depth=num_blocks,
-                           group_size=group_size,
-                           query_key_dim=query_key_dim,
-                           expansion_factor=expansion_factor,
-                           causal=causal,
-                           attn_dropout=attn_dropout
-                              )
+            dim=d_model,
+            depth=num_blocks,
+            group_size=group_size,
+            query_key_dim=query_key_dim,
+            expansion_factor=expansion_factor,
+            causal=causal,
+            attn_dropout=attn_dropout,
+        )
         self.norm = nn.LayerNorm(d_model, eps=1e-6)
 
     def forward(
@@ -371,7 +381,7 @@
         super(Computation_Block, self).__init__()
 
         ##MossFormer2M: MossFormer with recurrence
-        #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
+        # self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
         ##MossFormerM: the orignal MossFormer
         self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
         self.skip_around_intra = skip_around_intra
@@ -396,12 +406,12 @@
             Output tensor of dimension [B, N, S].
             where, B = Batchsize,
                N = number of filters
-               S = sequence time index 
+               S = sequence time index
         """
         B, N, S = x.shape
         # intra RNN
         # [B, S, N]
-        intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
+        intra = x.permute(0, 2, 1).contiguous()  # .view(B, S, N)
 
         intra = self.intra_mdl(intra)
 
@@ -416,4 +426,3 @@
 
         out = intra
         return out
-

--
Gitblit v1.9.1