liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/mossformer/mossformer_encoder.py
@@ -1,18 +1,20 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    # Optional dependency: only required when the MossFormer encoder is
    # actually instantiated. Narrowed from a bare `except:` so that genuine
    # errors (KeyboardInterrupt, SystemExit, bugs inside the package) are
    # not silently swallowed; only a missing package triggers the hint.
    print(
        "If you want to use mossformer, please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
    )
from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.models.transformer.embedding import ScaledSinuEmbedding
from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM
def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type.
    """
    """Just a wrapper to select the normalization type."""
    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
@@ -30,37 +32,50 @@
        *,
        dim,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        attn_dropout = 0.1,
        norm_type = 'scalenorm',
        shift_tokens = True
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
        attn_dropout=0.1,
        norm_type="scalenorm",
        shift_tokens=True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
        assert norm_type in (
            "scalenorm",
            "layernorm",
        ), "norm_type must be one of scalenorm or layernorm"
        if norm_type == 'scalenorm':
        if norm_type == "scalenorm":
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
        elif norm_type == "layernorm":
            norm_klass = nn.LayerNorm
        self.group_size = group_size
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
        self.layers = nn.ModuleList(
            [
                FLASH_ShareA_FFConvM(
                    dim=dim,
                    group_size=group_size,
                    query_key_dim=query_key_dim,
                    expansion_factor=expansion_factor,
                    causal=causal,
                    dropout=attn_dropout,
                    rotary_pos_emb=rotary_pos_emb,
                    norm_klass=norm_klass,
                    shift_tokens=shift_tokens,
                )
                for _ in range(depth)
            ]
        )
    def forward(
        self,
        x,
        *,
        mask = None
    ):
    def forward(self, x, *, mask=None):
        ii = 0
        for flash in self.layers:
            x = flash(x, mask = mask)
            x = flash(x, mask=mask)
            ii = ii + 1
        return x
@@ -119,25 +134,19 @@
            self.pos_enc = ScaledSinuEmbedding(out_channels)
        self.mdl = Computation_Block(
                    num_blocks,
                    out_channels,
                    norm,
                    skip_around_intra=skip_around_intra,
                )
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )
        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )
        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
    def forward(self, x):
        """Returns the output tensor.
@@ -165,18 +174,17 @@
        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            # x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #    x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
            emb = emb.transpose(0, -1)
            # print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb
        # [B, N, S]
        #for i in range(self.num_modules):
        # for i in range(self.num_modules):
        #    x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)
@@ -264,6 +272,7 @@
        return x
class MossFormerM(nn.Module):
    """This class implements the transformer encoder.
@@ -293,27 +302,28 @@
    >>> output.shape
    torch.Size([8, 60, 512])
    """
    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        attn_dropout = 0.1
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        attn_dropout=0.1,
    ):
        super().__init__()
        self.mossformerM = MossformerBlock(
                           dim=d_model,
                           depth=num_blocks,
                           group_size=group_size,
                           query_key_dim=query_key_dim,
                           expansion_factor=expansion_factor,
                           causal=causal,
                           attn_dropout=attn_dropout
                              )
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
    def forward(
@@ -371,7 +381,7 @@
        super(Computation_Block, self).__init__()
        ##MossFormer2M: MossFormer with recurrence
        #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
        # self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
        ##MossFormerM: the orignal MossFormer
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra
@@ -396,12 +406,12 @@
            Output tensor of dimension [B, N, S].
            where, B = Batchsize,
               N = number of filters
               S = sequence time index
               S = sequence time index
        """
        B, N, S = x.shape
        # intra RNN
        # [B, S, N]
        intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
        intra = x.permute(0, 2, 1).contiguous()  # .view(B, S, N)
        intra = self.intra_mdl(intra)
@@ -416,4 +426,3 @@
        out = intra
        return out