From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365
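
The 'res' assigned at line 514 conflicted with the 'res' defined at
line 365; renaming the later one avoids the clash. Apart from that
rename, the hunks below are formatting-only cleanups (double-quoted
strings, no spaces around keyword '=', one argument per line in long
calls), so runtime behavior is unchanged.

For reference, a minimal sketch of driving the reformatted block; the
dim/depth values here are illustrative (shapes follow the MossFormerM
docstring example), and rotary_embedding_torch must be installed:

    import torch
    from funasr.models.mossformer.mossformer_encoder import MossformerBlock

    # All constructor arguments are keyword-only; unspecified ones keep
    # the defaults shown in the diff (group_size=256, causal=False, ...).
    block = MossformerBlock(dim=512, depth=8)

    x = torch.rand(8, 60, 512)   # [batch, time, dim]
    out = block(x)               # shape is preserved: [8, 60, 512]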
---
funasr/models/mossformer/mossformer_encoder.py | 123 ++++++++++++++++++++++-------------------
1 file changed, 66 insertions(+), 57 deletions(-)
diff --git a/funasr/models/mossformer/mossformer_encoder.py b/funasr/models/mossformer/mossformer_encoder.py
index a28c960..888aef8 100644
--- a/funasr/models/mossformer/mossformer_encoder.py
+++ b/funasr/models/mossformer/mossformer_encoder.py
@@ -1,18 +1,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+
try:
from rotary_embedding_torch import RotaryEmbedding
except:
- print("If you want use mossformer, lease install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
+ print(
+ "If you want use mossformer, lease install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
+ )
from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.models.transformer.embedding import ScaledSinuEmbedding
from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM
def select_norm(norm, dim, shape):
- """Just a wrapper to select the normalization type.
- """
+ """Just a wrapper to select the normalization type."""
if norm == "gln":
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
@@ -30,37 +32,50 @@
*,
dim,
depth,
- group_size = 256,
- query_key_dim = 128,
- expansion_factor = 4.,
- causal = False,
- attn_dropout = 0.1,
- norm_type = 'scalenorm',
- shift_tokens = True
+ group_size=256,
+ query_key_dim=128,
+ expansion_factor=4.0,
+ causal=False,
+ attn_dropout=0.1,
+ norm_type="scalenorm",
+ shift_tokens=True,
):
super().__init__()
- assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+ assert norm_type in (
+ "scalenorm",
+ "layernorm",
+ ), "norm_type must be one of scalenorm or layernorm"
- if norm_type == 'scalenorm':
+ if norm_type == "scalenorm":
norm_klass = ScaleNorm
- elif norm_type == 'layernorm':
+ elif norm_type == "layernorm":
norm_klass = nn.LayerNorm
self.group_size = group_size
- rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+ rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
- self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
+ self.layers = nn.ModuleList(
+ [
+ FLASH_ShareA_FFConvM(
+ dim=dim,
+ group_size=group_size,
+ query_key_dim=query_key_dim,
+ expansion_factor=expansion_factor,
+ causal=causal,
+ dropout=attn_dropout,
+ rotary_pos_emb=rotary_pos_emb,
+ norm_klass=norm_klass,
+ shift_tokens=shift_tokens,
+ )
+ for _ in range(depth)
+ ]
+ )
- def forward(
- self,
- x,
- *,
- mask = None
- ):
+ def forward(self, x, *, mask=None):
ii = 0
for flash in self.layers:
- x = flash(x, mask = mask)
+ x = flash(x, mask=mask)
ii = ii + 1
return x
@@ -119,25 +134,19 @@
self.pos_enc = ScaledSinuEmbedding(out_channels)
self.mdl = Computation_Block(
- num_blocks,
- out_channels,
- norm,
- skip_around_intra=skip_around_intra,
- )
-
- self.conv1d_out = nn.Conv1d(
- out_channels, out_channels * num_spks, kernel_size=1
+ num_blocks,
+ out_channels,
+ norm,
+ skip_around_intra=skip_around_intra,
)
+
+ self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
self.prelu = nn.PReLU()
self.activation = nn.ReLU()
# gated output layer
- self.output = nn.Sequential(
- nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
- )
- self.output_gate = nn.Sequential(
- nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
- )
+ self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
+ self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
def forward(self, x):
"""Returns the output tensor.
@@ -165,18 +174,17 @@
# [B, N, L]
x = self.conv1d_encoder(x)
if self.use_global_pos_enc:
- #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
+ # x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
# x.size(1) ** 0.5)
base = x
x = x.transpose(1, -1)
emb = self.pos_enc(x)
- emb = emb.transpose(0, -1)
- #print('base: {}, emb: {}'.format(base.shape, emb.shape))
+ emb = emb.transpose(0, -1)
+ # print('base: {}, emb: {}'.format(base.shape, emb.shape))
x = base + emb
-
# [B, N, S]
- #for i in range(self.num_modules):
+ # for i in range(self.num_modules):
# x = self.dual_mdl[i](x)
x = self.mdl(x)
x = self.prelu(x)
@@ -264,6 +272,7 @@
return x
+
class MossFormerM(nn.Module):
"""This class implements the transformer encoder.
@@ -293,27 +302,28 @@
>>> output.shape
torch.Size([8, 60, 512])
"""
+
def __init__(
self,
num_blocks,
d_model=None,
causal=False,
- group_size = 256,
- query_key_dim = 128,
- expansion_factor = 4.,
- attn_dropout = 0.1
+ group_size=256,
+ query_key_dim=128,
+ expansion_factor=4.0,
+ attn_dropout=0.1,
):
super().__init__()
self.mossformerM = MossformerBlock(
- dim=d_model,
- depth=num_blocks,
- group_size=group_size,
- query_key_dim=query_key_dim,
- expansion_factor=expansion_factor,
- causal=causal,
- attn_dropout=attn_dropout
- )
+ dim=d_model,
+ depth=num_blocks,
+ group_size=group_size,
+ query_key_dim=query_key_dim,
+ expansion_factor=expansion_factor,
+ causal=causal,
+ attn_dropout=attn_dropout,
+ )
self.norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(
@@ -371,7 +381,7 @@
super(Computation_Block, self).__init__()
##MossFormer2M: MossFormer with recurrence
- #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
+ # self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
##MossFormerM: the original MossFormer
self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
self.skip_around_intra = skip_around_intra
@@ -396,12 +406,12 @@
Output tensor of dimension [B, N, S].
where, B = Batchsize,
N = number of filters
- S = sequence time index
+ S = sequence time index
"""
B, N, S = x.shape
# intra RNN
# [B, S, N]
- intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
+ intra = x.permute(0, 2, 1).contiguous() # .view(B, S, N)
intra = self.intra_mdl(intra)
@@ -416,4 +426,3 @@
out = intra
return out
-
--
Gitblit v1.9.1