import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    print(
        "If you want to use mossformer, please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
    )
from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.models.transformer.embedding import ScaledSinuEmbedding
from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type."""

    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
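    # Illustrative call (the values are examples, not taken from this file):
    # select_norm("gln", 512, 3) returns GlobalLayerNorm(512, 3, elementwise_affine=True).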


class MossformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True,
    ):
        super().__init__()
        assert norm_type in (
            "scalenorm",
            "layernorm",
        ), "norm_type must be one of scalenorm or layernorm"

        if norm_type == "scalenorm":
            norm_klass = ScaleNorm
        elif norm_type == "layernorm":
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        # max rotary embedding dimension of 32; partial rotary embeddings, from Wang et al. (GPT-J)
        self.layers = nn.ModuleList(
            [
                FLASH_ShareA_FFConvM(
                    dim=dim,
                    group_size=group_size,
                    query_key_dim=query_key_dim,
                    expansion_factor=expansion_factor,
                    causal=causal,
                    dropout=attn_dropout,
                    rotary_pos_emb=rotary_pos_emb,
                    norm_klass=norm_klass,
                    shift_tokens=shift_tokens,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x, *, mask=None):
        for flash in self.layers:
            x = flash(x, mask=mask)
        return x
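
    # Illustrative usage (the shapes below are assumptions, not values from this
    # file); each FLASH layer preserves the input shape, so the block maps
    # [B, S, N] -> [B, S, N]:
    # >>> block = MossformerBlock(dim=512, depth=6)
    # >>> x = torch.randn(2, 1000, 512)
    # >>> block(x).shape
    # torch.Size([2, 1000, 512])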


        self.pos_enc = ScaledSinuEmbedding(out_channels)

        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
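        # In the forward pass (not shown in this excerpt) these two branches are
        # typically combined multiplicatively, i.e. self.output(x) * self.output_gate(x):
        # a tanh feature path gated element-wise by a sigmoid path.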

    def forward(self, x):
        """Returns the output tensor."""

        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            # x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #     x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            # print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb

        # [B, N, S]
        # for i in range(self.num_modules):
        #     x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)

        return x


class MossFormerM(nn.Module):
    """This class implements the transformer encoder.

    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        attn_dropout=0.1,
    ):
        super().__init__()

        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
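
    # Illustrative shape contract, consistent with the docstring example above
    # (the argument values below are assumptions):
    # >>> x = torch.rand(8, 60, 512)                  # [B, S, N]
    # >>> net = MossFormerM(num_blocks=8, d_model=512)
    # >>> net(x).shape
    # torch.Size([8, 60, 512])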

    def forward(self, src):
        output = self.mossformerM(src)
        return self.norm(output)


class Computation_Block(nn.Module):
    def __init__(self, num_blocks, out_channels, norm, skip_around_intra=True):
        super(Computation_Block, self).__init__()

        ## MossFormer2M: MossFormer with recurrence
        # self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
        ## MossFormerM: the original MossFormer
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

    def forward(self, x):
        """
        Output tensor of dimension [B, N, S],
        where B = batch size,
              N = number of filters,
              S = sequence time index.
        """
        B, N, S = x.shape
        # intra RNN
        # [B, S, N]
        intra = x.permute(0, 2, 1).contiguous()  # .view(B, S, N)

        intra = self.intra_mdl(intra)

        # back to [B, N, S]
        intra = intra.permute(0, 2, 1).contiguous()
        if self.skip_around_intra:
            intra = intra + x

        out = intra
        return out
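
# Illustrative usage of Computation_Block (channel and length values are assumptions);
# per the docstring above, the block maps [B, N, S] -> [B, N, S]:
# >>> blk = Computation_Block(num_blocks=8, out_channels=512, norm="ln", skip_around_intra=True)
# >>> x = torch.randn(2, 512, 250)
# >>> blk(x).shape
# torch.Size([2, 512, 250])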