zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/mossformer/mossformer_encoder.py
@@ -1,18 +1,20 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 try:
     from rotary_embedding_torch import RotaryEmbedding
 except:
-    print("If you want use mossformer, lease install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
+    print(
+        "If you want to use mossformer, please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch"
+    )
 from funasr.models.transformer.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
 from funasr.models.transformer.embedding import ScaledSinuEmbedding
 from funasr.models.transformer.mossformer import FLASH_ShareA_FFConvM
 def select_norm(norm, dim, shape):
-    """Just a wrapper to select the normalization type.
-    """
+    """Just a wrapper to select the normalization type."""
     if norm == "gln":
         return GlobalLayerNorm(dim, shape, elementwise_affine=True)
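For orientation: select_norm is a small dispatcher that maps a norm name to a module; only the "gln" branch is visible in this hunk. A minimal usage sketch, assuming select_norm is imported from this module and that GlobalLayerNorm accepts the dual-path [batch, channels, time] layout:

    import torch

    # Hypothetical usage; only the "gln" branch appears in the hunk above.
    norm = select_norm("gln", dim=64, shape=3)  # GlobalLayerNorm over 64 channels
    x = torch.randn(2, 64, 1000)                # assumed [batch, channels, time]
    y = norm(x)                                 # normalized, same shape as x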
@@ -32,32 +34,45 @@
         depth,
-        group_size = 256,
-        query_key_dim = 128,
-        expansion_factor = 4.,
-        causal = False,
-        attn_dropout = 0.1,
-        norm_type = 'scalenorm',
-        shift_tokens = True
+        group_size=256,
+        query_key_dim=128,
+        expansion_factor=4.0,
+        causal=False,
+        attn_dropout=0.1,
+        norm_type="scalenorm",
+        shift_tokens=True,
     ):
         super().__init__()
-        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+        assert norm_type in (
+            "scalenorm",
+            "layernorm",
+        ), "norm_type must be one of scalenorm or layernorm"
-        if norm_type == 'scalenorm':
+        if norm_type == "scalenorm":
             norm_klass = ScaleNorm
-        elif norm_type == 'layernorm':
+        elif norm_type == "layernorm":
             norm_klass = nn.LayerNorm
         self.group_size = group_size
-        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
         # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
-        self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
+        self.layers = nn.ModuleList(
+            [
+                FLASH_ShareA_FFConvM(
+                    dim=dim,
+                    group_size=group_size,
+                    query_key_dim=query_key_dim,
+                    expansion_factor=expansion_factor,
+                    causal=causal,
+                    dropout=attn_dropout,
+                    rotary_pos_emb=rotary_pos_emb,
+                    norm_klass=norm_klass,
+                    shift_tokens=shift_tokens,
+                )
+                for _ in range(depth)
+            ]
+        )
-    def forward(
-        self,
-        x,
-        *,
-        mask = None
-    ):
+    def forward(self, x, *, mask=None):
         ii = 0
         for flash in self.layers:
-            x = flash(x, mask = mask)
+            x = flash(x, mask=mask)
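The reformatting makes the block's structure plain: the constructor stacks depth FLASH_ShareA_FFConvM layers that share one partial rotary embedding (capped at 32 dims, per the GPT-J trick), and forward simply threads x through them with an optional mask (the ii counter is unused leftover). A shape-level driver sketch; the enclosing class name and the [batch, seq, dim] layout are assumptions, since neither appears in this hunk:

    import torch

    # Hypothetical driver; "MossformerBlock" and the input layout are assumed.
    block = MossformerBlock(dim=512, depth=4)    # remaining kwargs keep their defaults
    x = torch.randn(2, 256, 512)                 # assumed [batch, seq, dim]
    mask = torch.ones(2, 256, dtype=torch.bool)  # True = valid position
    y = block(x, mask=mask)                      # output keeps the input shape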
@@ -125,19 +140,13 @@
                     skip_around_intra=skip_around_intra,
                 )
-        self.conv1d_out = nn.Conv1d(
-            out_channels, out_channels * num_spks, kernel_size=1
-        )
+        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
         self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
         self.prelu = nn.PReLU()
         self.activation = nn.ReLU()
         # gated output layer
-        self.output = nn.Sequential(
-            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
-        )
-        self.output_gate = nn.Sequential(
-            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
-        )
+        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
+        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
     def forward(self, x):
         """Returns the output tensor.
@@ -173,7 +182,6 @@
             emb = emb.transpose(0, -1)
-            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
+            # print('base: {}, emb: {}'.format(base.shape, emb.shape))
             x = base + emb
         # [B, N, S]
-        #for i in range(self.num_modules):
+        # for i in range(self.num_modules):
@@ -264,6 +272,7 @@
         return x
 
+
 class MossFormerM(nn.Module):
     """This class implements the transformer encoder.
@@ -293,6 +302,7 @@
     >>> output.shape
     torch.Size([8, 60, 512])
     """
+
     def __init__(
         self,
         num_blocks,
@@ -300,8 +310,8 @@
         causal=False,
-        group_size = 256,
-        query_key_dim = 128,
-        expansion_factor = 4.,
-        attn_dropout = 0.1
+        group_size=256,
+        query_key_dim=128,
+        expansion_factor=4.0,
+        attn_dropout=0.1,
     ):
         super().__init__()
@@ -312,7 +322,7 @@
-                           query_key_dim=query_key_dim,
-                           expansion_factor=expansion_factor,
-                           causal=causal,
-                           attn_dropout=attn_dropout
-                              )
+            query_key_dim=query_key_dim,
+            expansion_factor=expansion_factor,
+            causal=causal,
+            attn_dropout=attn_dropout,
+        )
         self.norm = nn.LayerNorm(d_model, eps=1e-6)
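With the trailing comma added, the reindented constructor call lines up with the class doctest quoted earlier. Expanded into a runnable form; the keyword names follow the __init__ shown above (d_model is cut from the hunk and assumed), and forward is assumed to return the LayerNorm-ed tensor directly, as the doctest's output shape suggests:

    import torch

    # Mirrors the doctest: a [8, 60, 512] input comes back with the same shape.
    net = MossFormerM(num_blocks=8, d_model=512)
    x = torch.rand(8, 60, 512)   # [batch, time, d_model]
    output = net(x)
    print(output.shape)          # torch.Size([8, 60, 512])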
@@ -416,4 +426,3 @@
         out = intra
         return out