From bce72487636cf84c463381096216e995deb1920d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 10 八月 2023 12:38:55 +0800
Subject: [PATCH] add mossformer code

---
 funasr/modules/embedding.py                 |   17 
 funasr/modules/mossformer.py                |  307 ++++++++++++
 funasr/build_utils/build_model.py           |    3 
 funasr/build_utils/build_ss_model.py        |   15 
 funasr/modules/layer_norm.py                |  135 +++++
 funasr/models/encoder/mossformer_encoder.py |  417 +++++++++++++++++
 funasr/models/e2e_ss.py                     |   95 +++
 funasr/bin/ss_inference_launch.py           |  253 ++++++++++
 funasr/bin/ss_infer.py                      |  127 +++++
 funasr/build_utils/build_model_from_file.py |   15 
 funasr/models/decoder/mossformer_decoder.py |   53 ++
 11 files changed, 1,437 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py
new file mode 100644
index 0000000..483967b
--- /dev/null
+++ b/funasr/bin/ss_infer.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+import logging
+from pathlib import Path
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.torch_utils.device_funcs import to_device
+
+
+class SpeechSeparator:
+    """SpeechSeparator class
+
+    Examples:
+        >>> import soundfile
+        >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> separated_wavs = speech_separator(audio)        
+
+    """
+
+    def __init__(
+            self,
+            ss_infer_config: Union[Path, str] = None,
+            ss_model_file: Union[Path, str] = None,
+            device: str = "cpu",
+            batch_size: int = 1,
+            dtype: str = "float32",
+            **kwargs,
+    ):
+
+        # 1. Build ss model
+        ss_model, ss_infer_args = build_model_from_file(
+            ss_infer_config, ss_model_file, None, device, task_name="ss"
+        )
+
+        logging.info("ss_model: {}".format(ss_model))
+        logging.info("ss_infer_args: {}".format(ss_infer_args))
+
+        ss_model.to(dtype=getattr(torch, dtype)).eval()
+
+        self.ss_model = ss_model
+        self.ss_infer_args = ss_infer_args
+        self.device = device
+        self.dtype = dtype
+        self.batch_size = batch_size
+
+    def decode(self, model, args, inputs, nsamples):
+        decode_do_segment = False
+        with torch.no_grad():       
+            out = []
+            window = args.sample_rate * args.decode_window  # decoding window length
+            stride = int(window*0.75)  # decoding stride if segmentation is used
+            b, t = inputs.shape
+            if t > window * args.one_time_decode_length:
+                decode_do_segment = True  # set segment decoding to true for very long sequence
+
+            if t < window:
+                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
+            elif t < window + stride:
+                padding = window + stride - t
+                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
+            else:
+                if (t - window) % stride != 0:
+                    padding = t - (t-window)//stride * stride
+                    inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
+            inputs = torch.from_numpy(np.float32(inputs))
+            inputs = to_device(inputs, device=self.device)
+            b, t = inputs.shape
+            if decode_do_segment:
+                outputs = np.zeros((args.num_spks, t))
+                give_up_length = (window - stride)//2
+                current_idx = 0
+                while current_idx + window <= t:
+                    tmp_input = inputs[:, current_idx:current_idx+window]
+                    tmp_out_list = model(tmp_input,)
+                    for spk in range(args.num_spks):
+                        tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
+                        if current_idx == 0:
+                            outputs[spk, current_idx:current_idx+window-give_up_length] = \
+                                tmp_out_list[spk][:-give_up_length]
+                        else:
+                            outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
+                                tmp_out_list[spk][give_up_length:-give_up_length]
+                    current_idx += stride
+                for spk in range(args.num_spks):
+                    out.append(outputs[spk, :])
+            else:
+                out_list = model(inputs)
+                for spk in range(args.num_spks):
+                    out.append(out_list[spk][0, :].cpu().numpy())
+
+            max_abs = 0
+            for spk in range(args.num_spks):
+                if max_abs < max(abs(out[spk])):
+                    max_abs = max(abs(out[spk]))
+            for spk in range(args.num_spks):
+                out[spk] = out[spk][:nsamples]
+                out[spk] = out[spk]/max_abs
+
+        return out
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+    ) -> List[torch.Tensor]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            speech list: list of speech data
+
+        """
+
+        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
+
+        return out
+
diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py
new file mode 100644
index 0000000..bab68ad
--- /dev/null
+++ b/funasr/bin/ss_inference_launch.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+import argparse
+import logging
+import os
+import sys
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import torch
+import soundfile as sf
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2triple_str
+from funasr.bin.ss_infer import SpeechSeparator
+
+
+def inference_ss(
+        batch_size: int,
+        ngpu: int,
+        log_level: Union[int, str],
+        ss_infer_config: Optional[str],
+        ss_model_file: Optional[str],
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        num_workers: int = 1,
+        num_spks: int = 2,
+        sample_rate: int = 8000,
+        param_dict: dict = None,
+        **kwargs,
+):
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+        batch_size = 1
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech separator
+    speech_separator_kwargs = dict(
+        ss_infer_config=ss_infer_config,
+        ss_model_file=ss_model_file,
+        device=device,
+        dtype=dtype,
+    )
+    logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
+    speech_separator = SpeechSeparator(**speech_separator_kwargs)
+
+    def _forward(
+            data_path_and_name_and_type,
+            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+            output_dir_v2: Optional[str] = None,
+            fs: dict = None,
+            param_dict: dict = None
+    ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = build_streaming_iterator(
+            task_name="ss",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
+        # 4 .Start for-loop
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if not os.path.exists(output_path):
+            cmd = 'mkdir -p ' + output_path 
+            os.system(cmd)       
+ 
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            # do speech separation
+            logging.info('decoding: {}'.format(keys[0]))
+            ss_results = speech_separator(**batch)
+            
+            for spk in range(num_spks):
+                sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate)
+        torch.cuda.empty_cache()
+        return ss_results
+
+    return _forward
+
+
+def inference_launch(mode, **kwargs):
+    if mode == "mossformer":
+        return inference_ss(**kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Speech Separator Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=1,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--njob",
+        type=int,
+        default=1,
+        help="The number of jobs for each gpu",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="2",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=True,
+        action="append",
+    )
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--ss_infer_config",
+        type=str,
+        help="SS infer configuration",
+    )
+    group.add_argument(
+        "--ss_model_file",
+        type=str,
+        help="SS model parameter file",
+    )
+    group.add_argument(
+        "--ss_train_config",
+        type=str,
+        help="SS training configuration",
+    )
+
+    group = parser.add_argument_group("The inference configuration related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+
+    parser.add_argument(
+        '--num-spks', dest='num_spks', type=int, default=2)
+
+    parser.add_argument(
+        '--one-time-decode-length', dest='one_time_decode_length', type=int,
+        default=60, help='the max length (second) for one-time decoding')
+
+    parser.add_argument(
+        '--decode-window', dest='decode_window', type=int,
+        default=1, help='segmental decoding window length (second)')
+
+    parser.add_argument(
+        '--sample-rate', dest='sample_rate', type=int, default='8000')
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="mossformer",
+        help="The decoding mode",
+    )
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+
+    # set logging messages
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # gpu setting
+    if args.ngpu > 0:
+        jobid = int(args.output_dir.split(".")[-1])
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+    inference_pipeline = inference_launch(**kwargs)
+    return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index be8f910..66fdfd0 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -5,6 +5,7 @@
 from funasr.build_utils.build_punc_model import build_punc_model
 from funasr.build_utils.build_sv_model import build_sv_model
 from funasr.build_utils.build_vad_model import build_vad_model
+from funasr.build_utils.build_ss_model import build_ss_model
 
 
 def build_model(args):
@@ -22,6 +23,8 @@
         model = build_diar_model(args)
     elif args.task_name == "sv":
         model = build_sv_model(args)
+    elif args.task_name == "ss":
+        model = build_ss_model(args)
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 26542cd..6130e71 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -11,6 +11,18 @@
 from funasr.models.base_model import FunASRModel
 
 
+def load_checkpoint(checkpoint_path, use_cuda=1):
+    if use_cuda:
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(
+            checkpoint_path, map_location=lambda storage, loc: storage)
+    return checkpoint
+
+def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
+    checkpoint = load_checkpoint(checkpoint_path, use_cuda)
+    model.load_state_dict(checkpoint['model'], strict=False)
+
 def build_model_from_file(
         config_file: Union[Path, str] = None,
         model_file: Union[Path, str] = None,
@@ -70,6 +82,9 @@
             model.load_state_dict(model_dict)
         else:
             model_dict = torch.load(model_file, map_location=device)
+    if task_name == 'ss':        
+        reload_ss_for_eval(model, model_file, use_cuda=True)
+        logging.info("model is loaded from path: {}".format(model_file))
     if task_name == "diar" and mode == "sond":
         model_dict = fileter_model_dict(model_dict, model.state_dict())
     if task_name == "vad":
diff --git a/funasr/build_utils/build_ss_model.py b/funasr/build_utils/build_ss_model.py
new file mode 100644
index 0000000..a6b5209
--- /dev/null
+++ b/funasr/build_utils/build_ss_model.py
@@ -0,0 +1,15 @@
+from funasr.models.e2e_ss import MossFormer
+
+def build_ss_model(args):
+    model = MossFormer(
+        in_channels=args.encoder_embedding_dim,
+        out_channels=args.mossformer_sequence_dim,
+        num_blocks=args.num_mossformer_layer,
+        kernel_size=args.encoder_kernel_size,
+        norm=args.norm,
+        num_spks=args.num_spks,
+        skip_around_intra=args.skip_around_intra,
+        use_global_pos_enc=args.use_global_pos_enc,
+        max_length=args.max_length)
+
+    return model
diff --git a/funasr/models/decoder/mossformer_decoder.py b/funasr/models/decoder/mossformer_decoder.py
new file mode 100644
index 0000000..e0189f7
--- /dev/null
+++ b/funasr/models/decoder/mossformer_decoder.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+
+class MossFormerDecoder(nn.ConvTranspose1d):
+    """A decoder layer that consists of ConvTranspose1d.
+
+    Arguments
+    ---------
+    kernel_size : int
+        Length of filters.
+    in_channels : int
+        Number of  input channels.
+    out_channels : int
+        Number of output channels.
+
+
+    Example
+    ---------
+    >>> x = torch.randn(2, 100, 1000)
+    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
+    >>> h = decoder(x)
+    >>> h.shape
+    torch.Size([2, 1003])
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(MossFormerDecoder, self).__init__(*args, **kwargs)
+
+    def forward(self, x):
+        """Return the decoded output.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor with dimensionality [B, N, L].
+                where, B = Batchsize,
+                       N = number of filters
+                       L = time points
+        """
+
+        if x.dim() not in [2, 3]:
+            raise RuntimeError(
+                "{} accept 3/4D tensor as input".format(self.__name__)
+            )
+        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
+
+        if torch.squeeze(x).dim() == 1:
+            x = torch.squeeze(x, dim=1)
+        else:
+            x = torch.squeeze(x)
+        return x
+
diff --git a/funasr/models/e2e_ss.py b/funasr/models/e2e_ss.py
new file mode 100644
index 0000000..1a46b3f
--- /dev/null
+++ b/funasr/models/e2e_ss.py
@@ -0,0 +1,95 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from funasr.models.base_model import FunASRModel
+from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet 
+from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
+
+
+class MossFormer(FunASRModel):
+    """The MossFormer model for separating input mixed speech into different speaker's speech.
+
+    Arguments
+    ---------
+    in_channels : int
+        Number of channels at the output of the encoder.
+    out_channels : int
+        Number of channels that would be inputted to the intra and inter blocks.
+    num_blocks : int
+        Number of layers of Dual Computation Block.
+    norm : str
+        Normalization type.
+    num_spks : int
+        Number of sources (speakers).
+    skip_around_intra : bool
+        Skip connection around intra.
+    use_global_pos_enc : bool
+        Global positional encodings.
+    max_length : int
+        Maximum sequence length.
+    kernel_size: int
+        Encoder and decoder kernel size
+    """
+
+    def __init__(
+        self,
+        in_channels=512,
+        out_channels=512,
+        num_blocks=24,
+        kernel_size=16,
+        norm="ln",
+        num_spks=2,
+        skip_around_intra=True,
+        use_global_pos_enc=True,
+        max_length=20000,
+    ):
+        super(MossFormer, self).__init__()
+        self.num_spks = num_spks
+        # Encoding
+        self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
+
+        ##Compute Mask
+        self.mask_net = MossFormer_MaskNet(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            num_blocks=num_blocks,
+            norm=norm,
+            num_spks=num_spks,
+            skip_around_intra=skip_around_intra,
+            use_global_pos_enc=use_global_pos_enc,
+            max_length=max_length,
+        )
+        self.dec = MossFormerDecoder(
+           in_channels=out_channels,
+           out_channels=1,
+           kernel_size=kernel_size,
+           stride = kernel_size//2,
+           bias=False
+        )
+    def forward(self, input):
+        x = self.enc(input)
+        mask = self.mask_net(x)
+        x = torch.stack([x] * self.num_spks)
+        sep_x = x * mask
+
+        # Decoding
+        est_source = torch.cat(
+            [
+                self.dec(sep_x[i]).unsqueeze(-1)
+                for i in range(self.num_spks)
+            ],
+            dim=-1,
+        )
+        T_origin = input.size(1)
+        T_est = est_source.size(1)
+        if T_origin > T_est:
+            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
+        else:
+            est_source = est_source[:, :T_origin, :]
+
+        out = []
+        for spk in range(self.num_spks):
+            out.append(est_source[:,:,spk])
+        return out
diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py
new file mode 100644
index 0000000..54d80ca
--- /dev/null
+++ b/funasr/models/encoder/mossformer_encoder.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from rotary_embedding_torch import RotaryEmbedding
+from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
+from funasr.modules.embedding import ScaledSinuEmbedding
+from funasr.modules.mossformer import FLASH_ShareA_FFConvM
+
+
+def select_norm(norm, dim, shape):
+    """Just a wrapper to select the normalization type.
+    """
+
+    if norm == "gln":
+        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
+    if norm == "cln":
+        return CumulativeLayerNorm(dim, elementwise_affine=True)
+    if norm == "ln":
+        return nn.GroupNorm(1, dim, eps=1e-8)
+    else:
+        return nn.BatchNorm1d(dim)
+
+
+class MossformerBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        group_size = 256, 
+        query_key_dim = 128, 
+        expansion_factor = 4.,
+        causal = False,
+        attn_dropout = 0.1,
+        norm_type = 'scalenorm',
+        shift_tokens = True
+    ):
+        super().__init__()
+        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+
+        if norm_type == 'scalenorm':
+            norm_klass = ScaleNorm
+        elif norm_type == 'layernorm':
+            norm_klass = nn.LayerNorm
+
+        self.group_size = group_size
+
+        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
+        self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
+
+    def forward(
+        self,
+        x,
+        *,
+        mask = None
+    ):
+        ii = 0
+        for flash in self.layers:
+            x = flash(x, mask = mask)
+            ii = ii + 1
+        return x
+
+
+class MossFormer_MaskNet(nn.Module):
+    """The MossFormer module for computing output masks.
+
+    Arguments
+    ---------
+    in_channels : int
+        Number of channels at the output of the encoder.
+    out_channels : int
+        Number of channels that would be inputted to the intra and inter blocks.
+    num_blocks : int
+        Number of layers of Dual Computation Block.
+    norm : str
+        Normalization type.
+    num_spks : int
+        Number of sources (speakers).
+    skip_around_intra : bool
+        Skip connection around intra.
+    use_global_pos_enc : bool
+        Global positional encodings.
+    max_length : int
+        Maximum sequence length.
+
+    Example
+    ---------
+    >>> mossformer_block = MossFormerM(1, 64, 8)
+    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
+    >>> x = torch.randn(10, 64, 2000)
+    >>> x = mossformer_masknet(x)
+    >>> x.shape
+    torch.Size([2, 10, 64, 2000])
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        num_blocks=24,
+        norm="ln",
+        num_spks=2,
+        skip_around_intra=True,
+        use_global_pos_enc=True,
+        max_length=20000,
+    ):
+        super(MossFormer_MaskNet, self).__init__()
+        self.num_spks = num_spks
+        self.num_blocks = num_blocks
+        self.norm = select_norm(norm, in_channels, 3)
+        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+        self.use_global_pos_enc = use_global_pos_enc
+
+        if self.use_global_pos_enc:
+            self.pos_enc = ScaledSinuEmbedding(out_channels)
+
+        self.mdl = Computation_Block(
+                    num_blocks,
+                    out_channels,
+                    norm,
+                    skip_around_intra=skip_around_intra,
+                )
+
+        self.conv1d_out = nn.Conv1d(
+            out_channels, out_channels * num_spks, kernel_size=1
+        )
+        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
+        self.prelu = nn.PReLU()
+        self.activation = nn.ReLU()
+        # gated output layer
+        self.output = nn.Sequential(
+            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
+        )
+        self.output_gate = nn.Sequential(
+            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        """Returns the output tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor of dimension [B, N, S].
+
+        Returns
+        -------
+        out : torch.Tensor
+            Output tensor of dimension [spks, B, N, S]
+            where, spks = Number of speakers
+               B = Batchsize,
+               N = number of filters
+               S = the number of time frames
+        """
+
+        # before each line we indicate the shape after executing the line
+
+        # [B, N, L]
+        x = self.norm(x)
+
+        # [B, N, L]
+        x = self.conv1d_encoder(x)
+        if self.use_global_pos_enc:
+            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
+            #    x.size(1) ** 0.5)
+            base = x
+            x = x.transpose(1, -1)
+            emb = self.pos_enc(x)
+            emb = emb.transpose(0, -1) 
+            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
+            x = base + emb
+            
+
+        # [B, N, S]
+        #for i in range(self.num_modules):
+        #    x = self.dual_mdl[i](x)
+        x = self.mdl(x)
+        x = self.prelu(x)
+
+        # [B, N*spks, S]
+        x = self.conv1d_out(x)
+        B, _, S = x.shape
+
+        # [B*spks, N, S]
+        x = x.view(B * self.num_spks, -1, S)
+
+        # [B*spks, N, S]
+        x = self.output(x) * self.output_gate(x)
+
+        # [B*spks, N, S]
+        x = self.conv1_decoder(x)
+
+        # [B, spks, N, S]
+        _, N, L = x.shape
+        x = x.view(B, self.num_spks, N, L)
+        x = self.activation(x)
+
+        # [spks, B, N, S]
+        x = x.transpose(0, 1)
+
+        return x
+
+
+class MossFormerEncoder(nn.Module):
+    """Convolutional Encoder Layer.
+
+    Arguments
+    ---------
+    kernel_size : int
+        Length of filters.
+    in_channels : int
+        Number of  input channels.
+    out_channels : int
+        Number of output channels.
+
+    Example
+    -------
+    >>> x = torch.randn(2, 1000)
+    >>> encoder = Encoder(kernel_size=4, out_channels=64)
+    >>> h = encoder(x)
+    >>> h.shape
+    torch.Size([2, 64, 499])
+    """
+
+    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
+        super(MossFormerEncoder, self).__init__()
+        self.conv1d = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=kernel_size // 2,
+            groups=1,
+            bias=False,
+        )
+        self.in_channels = in_channels
+
+    def forward(self, x):
+        """Return the encoded output.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor with dimensionality [B, L].
+        Return
+        ------
+        x : torch.Tensor
+            Encoded tensor with dimensionality [B, N, T_out].
+
+        where B = Batchsize
+              L = Number of timepoints
+              N = Number of filters
+              T_out = Number of timepoints at the output of the encoder
+        """
+        # B x L -> B x 1 x L
+        if self.in_channels == 1:
+            x = torch.unsqueeze(x, dim=1)
+        # B x 1 x L -> B x N x T_out
+        x = self.conv1d(x)
+        x = F.relu(x)
+
+        return x
+
+class MossFormerM(nn.Module):
+    """This class implements the transformer encoder.
+
+    Arguments
+    ---------
+    num_blocks : int
+        Number of mossformer blocks to include.
+    d_model : int
+        The dimension of the input embedding.
+    attn_dropout : float
+        Dropout for the self-attention (Optional).
+    group_size: int
+        the chunk size
+    query_key_dim: int
+        the attention vector dimension
+    expansion_factor: int
+        the expansion factor for the linear projection in conv module
+    causal: bool
+        true for causal / false for non causal
+
+    Example
+    -------
+    >>> import torch
+    >>> x = torch.rand((8, 60, 512))
+    >>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
+    >>> output, _ = net(x)
+    >>> output.shape
+    torch.Size([8, 60, 512])
+    """
+    def __init__(
+        self,
+        num_blocks,
+        d_model=None,
+        causal=False,
+        group_size = 256,
+        query_key_dim = 128,
+        expansion_factor = 4.,
+        attn_dropout = 0.1
+    ):
+        super().__init__()
+
+        self.mossformerM = MossformerBlock(
+                           dim=d_model,
+                           depth=num_blocks,
+                           group_size=group_size,
+                           query_key_dim=query_key_dim,
+                           expansion_factor=expansion_factor,
+                           causal=causal,
+                           attn_dropout=attn_dropout
+                              )
+        self.norm = nn.LayerNorm(d_model, eps=1e-6)
+
+    def forward(
+        self,
+        src,
+    ):
+        """
+        Arguments
+        ----------
+        src : torch.Tensor
+            Tensor shape [B, L, N],
+            where, B = Batchsize,
+                   L = time points
+                   N = number of filters
+            The sequence to the encoder layer (required).
+        src_mask : tensor
+            The mask for the src sequence (optional).
+        src_key_padding_mask : tensor
+            The mask for the src keys per batch (optional).
+        """
+        output = self.mossformerM(src)
+        output = self.norm(output)
+
+        return output
+
+
+class Computation_Block(nn.Module):
+    """Computation block for dual-path processing.
+
+    Arguments
+    ---------
+     out_channels : int
+        Dimensionality of inter/intra model.
+     norm : str
+        Normalization type.
+     skip_around_intra : bool
+        Skip connection around the intra layer.
+
+    Example
+    ---------
+        >>> comp_block = Computation_Block(64)
+        >>> x = torch.randn(10, 64, 100)
+        >>> x = comp_block(x)
+        >>> x.shape
+        torch.Size([10, 64, 100])
+    """
+
+    def __init__(
+        self,
+        num_blocks,
+        out_channels,
+        norm="ln",
+        skip_around_intra=True,
+    ):
+        super(Computation_Block, self).__init__()
+
+        ##MossFormer2M: MossFormer with recurrence
+        #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
+        ##MossFormerM: the orignal MossFormer
+        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
+        self.skip_around_intra = skip_around_intra
+
+        # Norm
+        self.norm = norm
+        if norm is not None:
+            self.intra_norm = select_norm(norm, out_channels, 3)
+
+    def forward(self, x):
+        """Returns the output tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Input tensor of dimension [B, N, S].
+
+
+        Return
+        ---------
+        out: torch.Tensor
+            Output tensor of dimension [B, N, S].
+            where, B = Batchsize,
+               N = number of filters
+               S = sequence time index 
+        """
+        B, N, S = x.shape
+        # intra RNN
+        # [B, S, N]
+        intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
+
+        intra = self.intra_mdl(intra)
+
+        # [B, N, S]
+        intra = intra.permute(0, 2, 1).contiguous()
+        if self.norm is not None:
+            intra = self.intra_norm(intra)
+
+        # [B, N, S]
+        if self.skip_around_intra:
+            intra = intra + x
+
+        out = intra
+        return out
+
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 374eba4..1995bbe 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -9,6 +9,7 @@
 import math
 import torch
 import torch.nn.functional as F
+from torch import einsum
 
 def _pre_hook(
     state_dict,
@@ -510,3 +511,19 @@
         pos_enc = self.dropout(pos_enc)
 
         return pos_enc
+
+
+class ScaledSinuEmbedding(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = torch.nn.Parameter(torch.ones(1,))
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        n, device = x.shape[1], x.device
+        t = torch.arange(n, device = device).type_as(self.inv_freq)
+        sinu = einsum('i , j -> i j', t, self.inv_freq)
+        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+        return emb * self.scale
+
diff --git a/funasr/modules/layer_norm.py b/funasr/modules/layer_norm.py
index 6e934e6..8683230 100644
--- a/funasr/modules/layer_norm.py
+++ b/funasr/modules/layer_norm.py
@@ -7,6 +7,7 @@
 """Layer normalization module."""
 
 import torch
+import torch.nn as nn
 
 
 class LayerNorm(torch.nn.LayerNorm):
@@ -40,3 +41,137 @@
             .forward(x.transpose(self.dim, -1))
             .transpose(self.dim, -1)
         )
+
+
+class GlobalLayerNorm(nn.Module):
+    """Calculate Global Layer Normalization.
+
+    Arguments
+    ---------
+       dim : (int or list or torch.Size)
+           Input shape from an expected input of size.
+       eps : float
+           A value added to the denominator for numerical stability.
+       elementwise_affine : bool
+          A boolean value that when set to True,
+          this module has learnable per-element affine parameters
+          initialized to ones (for weights) and zeros (for biases).
+
+    Example
+    -------
+    >>> x = torch.randn(5, 10, 20)
+    >>> GLN = GlobalLayerNorm(10, 3)
+    >>> x_norm = GLN(x)
+    """
+
+    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
+        super(GlobalLayerNorm, self).__init__()
+        self.dim = dim
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            if shape == 3:
+                self.weight = nn.Parameter(torch.ones(self.dim, 1))
+                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
+            if shape == 4:
+                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
+                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+    def forward(self, x):
+        """Returns the normalized tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of size [N, C, K, S] or [N, C, L].
+        """
+        # x = N x C x K x S or N x C x L
+        # N x 1 x 1
+        # cln: mean,var N x 1 x K x S
+        # gln: mean,var N x 1 x 1
+        if x.dim() == 3:
+            mean = torch.mean(x, (1, 2), keepdim=True)
+            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
+            if self.elementwise_affine:
+                x = (
+                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
+                    + self.bias
+                )
+            else:
+                x = (x - mean) / torch.sqrt(var + self.eps)
+
+        if x.dim() == 4:
+            mean = torch.mean(x, (1, 2, 3), keepdim=True)
+            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
+            if self.elementwise_affine:
+                x = (
+                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
+                    + self.bias
+                )
+            else:
+                x = (x - mean) / torch.sqrt(var + self.eps)
+        return x
+
+
+class CumulativeLayerNorm(nn.LayerNorm):
+    """Calculate Cumulative Layer Normalization.
+
+       Arguments
+       ---------
+       dim : int
+        Dimension that you want to normalize.
+       elementwise_affine : True
+        Learnable per-element affine parameters.
+
+    Example
+    -------
+    >>> x = torch.randn(5, 10, 20)
+    >>> CLN = CumulativeLayerNorm(10)
+    >>> x_norm = CLN(x)
+    """
+
+    def __init__(self, dim, elementwise_affine=True):
+        super(CumulativeLayerNorm, self).__init__(
+            dim, elementwise_affine=elementwise_affine, eps=1e-8
+        )
+
+    def forward(self, x):
+        """Returns the normalized tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor size [N, C, K, S] or [N, C, L]
+        """
+        # x: N x C x K x S or N x C x L
+        # N x K x S x C
+        if x.dim() == 4:
+            x = x.permute(0, 2, 3, 1).contiguous()
+            # N x K x S x C == only channel norm
+            x = super().forward(x)
+            # N x C x K x S
+            x = x.permute(0, 3, 1, 2).contiguous()
+        if x.dim() == 3:
+            x = torch.transpose(x, 1, 2)
+            # N x L x C == only channel norm
+            x = super().forward(x)
+            # N x C x L
+            x = torch.transpose(x, 1, 2)
+        return x
+
+
+class ScaleNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g
+
diff --git a/funasr/modules/mossformer.py b/funasr/modules/mossformer.py
new file mode 100644
index 0000000..f1e8e28
--- /dev/null
+++ b/funasr/modules/mossformer.py
@@ -0,0 +1,307 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+
+
+def identity(t, *args, **kwargs):
+    return t
+
+def append_dims(x, num_dims):
+    if num_dims <= 0:
+        return x
+    return x.view(*x.shape, *((1,) * num_dims))
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def padding_to_multiple_of(n, mult):
+    remainder = n % mult
+    if remainder == 0:
+        return 0
+    return mult - remainder
+
+
+class Transpose(nn.Module):
+    """ Wrapper class of torch.transpose() for Sequential module. """
+    def __init__(self, shape: tuple):
+        super(Transpose, self).__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        return x.transpose(*self.shape)
+
+
+class DepthwiseConv1d(nn.Module):
+    """
+    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+    this operation is termed in literature as depthwise convolution.
+    Args:
+        in_channels (int): Number of channels in the input
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+    Inputs: inputs
+        - **inputs** (batch, in_channels, time): Tensor containing input vector
+    Returns: outputs
+        - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
+    """
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int,
+            stride: int = 1,
+            padding: int = 0,
+            bias: bool = False,
+    ) -> None:
+        super(DepthwiseConv1d, self).__init__()
+        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            groups=in_channels,
+            stride=stride,
+            padding=padding,
+            bias=bias,
+        )
+
+    def forward(self, inputs):
+        return self.conv(inputs)
+
+
+class ConvModule(nn.Module):
+    """
+    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
+    This is followed by a single 1-D depthwise convolution layer. Batchnorm is  deployed just after the convolution
+    to aid training deep models.
+    Args:
+        in_channels (int): Number of channels in the input
+        kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
+        dropout_p (float, optional): probability of dropout
+    Inputs: inputs
+        inputs (batch, time, dim): Tensor contains input sequences
+    Outputs: outputs
+        outputs (batch, time, dim): Tensor produces by conformer convolution module.
+    """
+    def __init__(
+            self,
+            in_channels: int,
+            kernel_size: int = 17,
+            expansion_factor: int = 2,
+            dropout_p: float = 0.1,
+    ) -> None:
+        super(ConvModule, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
+        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
+
+        self.sequential = nn.Sequential(
+            Transpose(shape=(1, 2)),
+            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+        )
+
+    def forward(self, inputs):
+        return inputs + self.sequential(inputs).transpose(1, 2)
+
+
+class OffsetScale(nn.Module):
+    def __init__(self, dim, heads = 1):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(heads, dim))
+        self.beta = nn.Parameter(torch.zeros(heads, dim))
+        nn.init.normal_(self.gamma, std = 0.02)
+
+    def forward(self, x):
+        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
+        return out.unbind(dim = -2)
+
+
+class FFConvM(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        norm_klass = nn.LayerNorm,
+        dropout = 0.1
+    ):
+        super().__init__()
+        self.mdl = nn.Sequential(
+            norm_klass(dim_in),
+            nn.Linear(dim_in, dim_out),
+            nn.SiLU(),
+            ConvModule(dim_out),
+            nn.Dropout(dropout)
+        )
+    def forward(
+        self,
+        x,
+    ):
+        output = self.mdl(x)
+        return output
+
+
+class FLASH_ShareA_FFConvM(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        group_size = 256,
+        query_key_dim = 128,
+        expansion_factor = 1.,
+        causal = False,
+        dropout = 0.1,
+        rotary_pos_emb = None,
+        norm_klass = nn.LayerNorm,
+        shift_tokens = True
+    ):
+        super().__init__()
+        hidden_dim = int(dim * expansion_factor)        
+        self.group_size = group_size
+        self.causal = causal
+        self.shift_tokens = shift_tokens
+
+        # positional embeddings
+        self.rotary_pos_emb = rotary_pos_emb
+        # norm
+        self.dropout = nn.Dropout(dropout)
+        # projections
+        
+        self.to_hidden = FFConvM(
+            dim_in = dim,
+            dim_out = hidden_dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+            )
+        self.to_qk = FFConvM(
+            dim_in = dim,
+            dim_out = query_key_dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+            )
+
+        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+
+        self.to_out = FFConvM(
+            dim_in = dim*2,
+            dim_out = dim,
+            norm_klass = norm_klass,
+            dropout = dropout,
+            )
+        
+        self.gateActivate=nn.Sigmoid() 
+
+    def forward(
+        self,
+        x,
+        *,
+        mask = None
+    ):
+
+        """
+        b - batch
+        n - sequence length (within groups)
+        g - group dimension
+        d - feature dimension (keys)
+        e - feature dimension (values)
+        i - sequence dimension (source)
+        j - sequence dimension (target)
+        """
+
+        normed_x = x 
+
+        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
+        residual = x
+
+        if self.shift_tokens:
+            x_shift, x_pass = normed_x.chunk(2, dim = -1)
+            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
+            normed_x = torch.cat((x_shift, x_pass), dim = -1)
+
+        # initial projections
+
+        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
+        qk = self.to_qk(normed_x)
+
+        # offset and scale
+        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
+        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
+        out = (att_u*v ) * self.gateActivate(att_v*u)        
+        x = x + self.to_out(out)
+        return x
+
+    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
+        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+
+        if exists(mask):
+            lin_mask = rearrange(mask, '... -> ... 1')
+            lin_k = lin_k.masked_fill(~lin_mask, 0.)
+
+        # rotate queries and keys
+
+        if exists(self.rotary_pos_emb):
+            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+
+        # padding for groups
+
+        padding = padding_to_multiple_of(n, g)
+
+        if padding > 0:
+            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
+            mask = F.pad(mask, (0, padding), value = False)
+
+        # group along sequence
+
+        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+        if exists(mask):
+            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+
+        # calculate quadratic attention output
+
+        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+
+        attn = F.relu(sim) ** 2
+        attn = self.dropout(attn)
+
+        if exists(mask):
+            attn = attn.masked_fill(~mask, 0.)
+
+        if self.causal:
+            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+            attn = attn.masked_fill(causal_mask, 0.)
+
+        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
+        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
+
+        # calculate linear attention output
+
+        if self.causal:
+            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+            # exclusive cumulative sum along group dimension
+            lin_kv = lin_kv.cumsum(dim = 1)
+            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
+            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+
+            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
+            # exclusive cumulative sum along group dimension
+            lin_ku = lin_ku.cumsum(dim = 1)
+            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
+            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
+        else:
+            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
+            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
+
+            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
+            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
+
+        # fold back groups into full sequence, and excise out padding
+        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
+

--
Gitblit v1.9.1