From bce72487636cf84c463381096216e995deb1920d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 10 八月 2023 12:38:55 +0800
Subject: [PATCH] add mossformer code
---
funasr/modules/embedding.py | 17
funasr/modules/mossformer.py | 307 ++++++++++++
funasr/build_utils/build_model.py | 3
funasr/build_utils/build_ss_model.py | 15
funasr/modules/layer_norm.py | 135 +++++
funasr/models/encoder/mossformer_encoder.py | 417 +++++++++++++++++
funasr/models/e2e_ss.py | 95 +++
funasr/bin/ss_inference_launch.py | 253 ++++++++++
funasr/bin/ss_infer.py | 127 +++++
funasr/build_utils/build_model_from_file.py | 15
funasr/models/decoder/mossformer_decoder.py | 53 ++
11 files changed, 1,437 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py
new file mode 100644
index 0000000..483967b
--- /dev/null
+++ b/funasr/bin/ss_infer.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+
+import logging
+from pathlib import Path
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.torch_utils.device_funcs import to_device
+
+
+class SpeechSeparator:
+ """SpeechSeparator class
+
+ Examples:
+ >>> import soundfile
+ >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> separated_wavs = speech_separator(audio)
+
+ """
+
+ def __init__(
+ self,
+ ss_infer_config: Union[Path, str] = None,
+ ss_model_file: Union[Path, str] = None,
+ device: str = "cpu",
+ batch_size: int = 1,
+ dtype: str = "float32",
+ **kwargs,
+ ):
+
+ # 1. Build ss model
+ ss_model, ss_infer_args = build_model_from_file(
+ ss_infer_config, ss_model_file, None, device, task_name="ss"
+ )
+
+ logging.info("ss_model: {}".format(ss_model))
+ logging.info("ss_infer_args: {}".format(ss_infer_args))
+
+ ss_model.to(dtype=getattr(torch, dtype)).eval()
+
+ self.ss_model = ss_model
+ self.ss_infer_args = ss_infer_args
+ self.device = device
+ self.dtype = dtype
+ self.batch_size = batch_size
+
+ def decode(self, model, args, inputs, nsamples):
+ decode_do_segment = False
+ with torch.no_grad():
+ out = []
+ window = args.sample_rate * args.decode_window # decoding window length
+ stride = int(window*0.75) # decoding stride if segmentation is used
+ b, t = inputs.shape
+ if t > window * args.one_time_decode_length:
+ decode_do_segment = True # set segment decoding to true for very long sequence
+
+ if t < window:
+ inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
+ elif t < window + stride:
+ padding = window + stride - t
+ inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
+ else:
+ if (t - window) % stride != 0:
+ padding = t - (t-window)//stride * stride
+ inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
+ inputs = torch.from_numpy(np.float32(inputs))
+ inputs = to_device(inputs, device=self.device)
+ b, t = inputs.shape
+ if decode_do_segment:
+ outputs = np.zeros((args.num_spks, t))
+ give_up_length = (window - stride)//2
+ current_idx = 0
+ while current_idx + window <= t:
+ tmp_input = inputs[:, current_idx:current_idx+window]
+ tmp_out_list = model(tmp_input,)
+ for spk in range(args.num_spks):
+ tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
+ if current_idx == 0:
+ outputs[spk, current_idx:current_idx+window-give_up_length] = \
+ tmp_out_list[spk][:-give_up_length]
+ else:
+ outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
+ tmp_out_list[spk][give_up_length:-give_up_length]
+ current_idx += stride
+ for spk in range(args.num_spks):
+ out.append(outputs[spk, :])
+ else:
+ out_list = model(inputs)
+ for spk in range(args.num_spks):
+ out.append(out_list[spk][0, :].cpu().numpy())
+
+ max_abs = 0
+ for spk in range(args.num_spks):
+ if max_abs < max(abs(out[spk])):
+ max_abs = max(abs(out[spk]))
+ for spk in range(args.num_spks):
+ out[spk] = out[spk][:nsamples]
+ out[spk] = out[spk]/max_abs
+
+ return out
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ ) -> List[torch.Tensor]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ speech list: list of speech data
+
+ """
+
+ out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
+
+ return out
+
diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py
new file mode 100644
index 0000000..bab68ad
--- /dev/null
+++ b/funasr/bin/ss_inference_launch.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+
+import argparse
+import logging
+import os
+import sys
+from typing import Optional
+from typing import Union
+
+import numpy as np
+import torch
+import soundfile as sf
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2triple_str
+from funasr.bin.ss_infer import SpeechSeparator
+
+
+def inference_ss(
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ ss_infer_config: Optional[str],
+ ss_model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ num_spks: int = 2,
+ sample_rate: int = 8000,
+ param_dict: dict = None,
+ **kwargs,
+):
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech separator
+ speech_separator_kwargs = dict(
+ ss_infer_config=ss_infer_config,
+ ss_model_file=ss_model_file,
+ device=device,
+ dtype=dtype,
+ )
+ logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
+ speech_separator = SpeechSeparator(**speech_separator_kwargs)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = build_streaming_iterator(
+ task_name="ss",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
+
+ # 4 .Start for-loop
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if not os.path.exists(output_path):
+ cmd = 'mkdir -p ' + output_path
+ os.system(cmd)
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ # do speech separation
+ logging.info('decoding: {}'.format(keys[0]))
+ ss_results = speech_separator(**batch)
+
+ for spk in range(num_spks):
+ sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate)
+ torch.cuda.empty_cache()
+ return ss_results
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "mossformer":
+ return inference_ss(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Speech Separator Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=1,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--njob",
+ type=int,
+ default=1,
+ help="The number of jobs for each gpu",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="2",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=True,
+ action="append",
+ )
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--ss_infer_config",
+ type=str,
+ help="SS infer configuration",
+ )
+ group.add_argument(
+ "--ss_model_file",
+ type=str,
+ help="SS model parameter file",
+ )
+ group.add_argument(
+ "--ss_train_config",
+ type=str,
+ help="SS training configuration",
+ )
+
+ group = parser.add_argument_group("The inference configuration related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+
+ parser.add_argument(
+ '--num-spks', dest='num_spks', type=int, default=2)
+
+ parser.add_argument(
+ '--one-time-decode-length', dest='one_time_decode_length', type=int,
+ default=60, help='the max length (second) for one-time decoding')
+
+ parser.add_argument(
+ '--decode-window', dest='decode_window', type=int,
+ default=1, help='segmental decoding window length (second)')
+
+ parser.add_argument(
+ '--sample-rate', dest='sample_rate', type=int, default='8000')
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="mossformer",
+ help="The decoding mode",
+ )
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index be8f910..66fdfd0 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -5,6 +5,7 @@
from funasr.build_utils.build_punc_model import build_punc_model
from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
+from funasr.build_utils.build_ss_model import build_ss_model
def build_model(args):
@@ -22,6 +23,8 @@
model = build_diar_model(args)
elif args.task_name == "sv":
model = build_sv_model(args)
+ elif args.task_name == "ss":
+ model = build_ss_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 26542cd..6130e71 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -11,6 +11,18 @@
from funasr.models.base_model import FunASRModel
+def load_checkpoint(checkpoint_path, use_cuda=1):
+ if use_cuda:
+ checkpoint = torch.load(checkpoint_path)
+ else:
+ checkpoint = torch.load(
+ checkpoint_path, map_location=lambda storage, loc: storage)
+ return checkpoint
+
+def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
+ checkpoint = load_checkpoint(checkpoint_path, use_cuda)
+ model.load_state_dict(checkpoint['model'], strict=False)
+
def build_model_from_file(
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
@@ -70,6 +82,9 @@
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
+ if task_name == 'ss':
+ reload_ss_for_eval(model, model_file, use_cuda=True)
+ logging.info("model is loaded from path: {}".format(model_file))
if task_name == "diar" and mode == "sond":
model_dict = fileter_model_dict(model_dict, model.state_dict())
if task_name == "vad":
diff --git a/funasr/build_utils/build_ss_model.py b/funasr/build_utils/build_ss_model.py
new file mode 100644
index 0000000..a6b5209
--- /dev/null
+++ b/funasr/build_utils/build_ss_model.py
@@ -0,0 +1,15 @@
+from funasr.models.e2e_ss import MossFormer
+
+def build_ss_model(args):
+ model = MossFormer(
+ in_channels=args.encoder_embedding_dim,
+ out_channels=args.mossformer_sequence_dim,
+ num_blocks=args.num_mossformer_layer,
+ kernel_size=args.encoder_kernel_size,
+ norm=args.norm,
+ num_spks=args.num_spks,
+ skip_around_intra=args.skip_around_intra,
+ use_global_pos_enc=args.use_global_pos_enc,
+ max_length=args.max_length)
+
+ return model
diff --git a/funasr/models/decoder/mossformer_decoder.py b/funasr/models/decoder/mossformer_decoder.py
new file mode 100644
index 0000000..e0189f7
--- /dev/null
+++ b/funasr/models/decoder/mossformer_decoder.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+
+class MossFormerDecoder(nn.ConvTranspose1d):
+ """A decoder layer that consists of ConvTranspose1d.
+
+ Arguments
+ ---------
+ kernel_size : int
+ Length of filters.
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+
+
+ Example
+ ---------
+ >>> x = torch.randn(2, 100, 1000)
+ >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
+ >>> h = decoder(x)
+ >>> h.shape
+ torch.Size([2, 1003])
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(MossFormerDecoder, self).__init__(*args, **kwargs)
+
+ def forward(self, x):
+ """Return the decoded output.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Input tensor with dimensionality [B, N, L].
+ where, B = Batchsize,
+ N = number of filters
+ L = time points
+ """
+
+ if x.dim() not in [2, 3]:
+ raise RuntimeError(
+ "{} accept 3/4D tensor as input".format(self.__name__)
+ )
+ x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
+
+ if torch.squeeze(x).dim() == 1:
+ x = torch.squeeze(x, dim=1)
+ else:
+ x = torch.squeeze(x)
+ return x
+
diff --git a/funasr/models/e2e_ss.py b/funasr/models/e2e_ss.py
new file mode 100644
index 0000000..1a46b3f
--- /dev/null
+++ b/funasr/models/e2e_ss.py
@@ -0,0 +1,95 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from funasr.models.base_model import FunASRModel
+from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
+from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
+
+
+class MossFormer(FunASRModel):
+ """The MossFormer model for separating input mixed speech into different speaker's speech.
+
+ Arguments
+ ---------
+ in_channels : int
+ Number of channels at the output of the encoder.
+ out_channels : int
+ Number of channels that would be inputted to the intra and inter blocks.
+ num_blocks : int
+ Number of layers of Dual Computation Block.
+ norm : str
+ Normalization type.
+ num_spks : int
+ Number of sources (speakers).
+ skip_around_intra : bool
+ Skip connection around intra.
+ use_global_pos_enc : bool
+ Global positional encodings.
+ max_length : int
+ Maximum sequence length.
+ kernel_size: int
+ Encoder and decoder kernel size
+ """
+
+ def __init__(
+ self,
+ in_channels=512,
+ out_channels=512,
+ num_blocks=24,
+ kernel_size=16,
+ norm="ln",
+ num_spks=2,
+ skip_around_intra=True,
+ use_global_pos_enc=True,
+ max_length=20000,
+ ):
+ super(MossFormer, self).__init__()
+ self.num_spks = num_spks
+ # Encoding
+ self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
+
+ ##Compute Mask
+ self.mask_net = MossFormer_MaskNet(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ norm=norm,
+ num_spks=num_spks,
+ skip_around_intra=skip_around_intra,
+ use_global_pos_enc=use_global_pos_enc,
+ max_length=max_length,
+ )
+ self.dec = MossFormerDecoder(
+ in_channels=out_channels,
+ out_channels=1,
+ kernel_size=kernel_size,
+ stride = kernel_size//2,
+ bias=False
+ )
+ def forward(self, input):
+ x = self.enc(input)
+ mask = self.mask_net(x)
+ x = torch.stack([x] * self.num_spks)
+ sep_x = x * mask
+
+ # Decoding
+ est_source = torch.cat(
+ [
+ self.dec(sep_x[i]).unsqueeze(-1)
+ for i in range(self.num_spks)
+ ],
+ dim=-1,
+ )
+ T_origin = input.size(1)
+ T_est = est_source.size(1)
+ if T_origin > T_est:
+ est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
+ else:
+ est_source = est_source[:, :T_origin, :]
+
+ out = []
+ for spk in range(self.num_spks):
+ out.append(est_source[:,:,spk])
+ return out
diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py
new file mode 100644
index 0000000..54d80ca
--- /dev/null
+++ b/funasr/models/encoder/mossformer_encoder.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from rotary_embedding_torch import RotaryEmbedding
+from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
+from funasr.modules.embedding import ScaledSinuEmbedding
+from funasr.modules.mossformer import FLASH_ShareA_FFConvM
+
+
+def select_norm(norm, dim, shape):
+ """Just a wrapper to select the normalization type.
+ """
+
+ if norm == "gln":
+ return GlobalLayerNorm(dim, shape, elementwise_affine=True)
+ if norm == "cln":
+ return CumulativeLayerNorm(dim, elementwise_affine=True)
+ if norm == "ln":
+ return nn.GroupNorm(1, dim, eps=1e-8)
+ else:
+ return nn.BatchNorm1d(dim)
+
+
+class MossformerBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ group_size = 256,
+ query_key_dim = 128,
+ expansion_factor = 4.,
+ causal = False,
+ attn_dropout = 0.1,
+ norm_type = 'scalenorm',
+ shift_tokens = True
+ ):
+ super().__init__()
+ assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
+
+ if norm_type == 'scalenorm':
+ norm_klass = ScaleNorm
+ elif norm_type == 'layernorm':
+ norm_klass = nn.LayerNorm
+
+ self.group_size = group_size
+
+ rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
+ # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
+ self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
+
+ def forward(
+ self,
+ x,
+ *,
+ mask = None
+ ):
+ ii = 0
+ for flash in self.layers:
+ x = flash(x, mask = mask)
+ ii = ii + 1
+ return x
+
+
+class MossFormer_MaskNet(nn.Module):
+ """The MossFormer module for computing output masks.
+
+ Arguments
+ ---------
+ in_channels : int
+ Number of channels at the output of the encoder.
+ out_channels : int
+ Number of channels that would be inputted to the intra and inter blocks.
+ num_blocks : int
+ Number of layers of Dual Computation Block.
+ norm : str
+ Normalization type.
+ num_spks : int
+ Number of sources (speakers).
+ skip_around_intra : bool
+ Skip connection around intra.
+ use_global_pos_enc : bool
+ Global positional encodings.
+ max_length : int
+ Maximum sequence length.
+
+ Example
+ ---------
+ >>> mossformer_block = MossFormerM(1, 64, 8)
+ >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
+ >>> x = torch.randn(10, 64, 2000)
+ >>> x = mossformer_masknet(x)
+ >>> x.shape
+ torch.Size([2, 10, 64, 2000])
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ num_blocks=24,
+ norm="ln",
+ num_spks=2,
+ skip_around_intra=True,
+ use_global_pos_enc=True,
+ max_length=20000,
+ ):
+ super(MossFormer_MaskNet, self).__init__()
+ self.num_spks = num_spks
+ self.num_blocks = num_blocks
+ self.norm = select_norm(norm, in_channels, 3)
+ self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
+ self.use_global_pos_enc = use_global_pos_enc
+
+ if self.use_global_pos_enc:
+ self.pos_enc = ScaledSinuEmbedding(out_channels)
+
+ self.mdl = Computation_Block(
+ num_blocks,
+ out_channels,
+ norm,
+ skip_around_intra=skip_around_intra,
+ )
+
+ self.conv1d_out = nn.Conv1d(
+ out_channels, out_channels * num_spks, kernel_size=1
+ )
+ self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
+ self.prelu = nn.PReLU()
+ self.activation = nn.ReLU()
+ # gated output layer
+ self.output = nn.Sequential(
+ nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
+ )
+ self.output_gate = nn.Sequential(
+ nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ """Returns the output tensor.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Input tensor of dimension [B, N, S].
+
+ Returns
+ -------
+ out : torch.Tensor
+ Output tensor of dimension [spks, B, N, S]
+ where, spks = Number of speakers
+ B = Batchsize,
+ N = number of filters
+ S = the number of time frames
+ """
+
+ # before each line we indicate the shape after executing the line
+
+ # [B, N, L]
+ x = self.norm(x)
+
+ # [B, N, L]
+ x = self.conv1d_encoder(x)
+ if self.use_global_pos_enc:
+ #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
+ # x.size(1) ** 0.5)
+ base = x
+ x = x.transpose(1, -1)
+ emb = self.pos_enc(x)
+ emb = emb.transpose(0, -1)
+ #print('base: {}, emb: {}'.format(base.shape, emb.shape))
+ x = base + emb
+
+
+ # [B, N, S]
+ #for i in range(self.num_modules):
+ # x = self.dual_mdl[i](x)
+ x = self.mdl(x)
+ x = self.prelu(x)
+
+ # [B, N*spks, S]
+ x = self.conv1d_out(x)
+ B, _, S = x.shape
+
+ # [B*spks, N, S]
+ x = x.view(B * self.num_spks, -1, S)
+
+ # [B*spks, N, S]
+ x = self.output(x) * self.output_gate(x)
+
+ # [B*spks, N, S]
+ x = self.conv1_decoder(x)
+
+ # [B, spks, N, S]
+ _, N, L = x.shape
+ x = x.view(B, self.num_spks, N, L)
+ x = self.activation(x)
+
+ # [spks, B, N, S]
+ x = x.transpose(0, 1)
+
+ return x
+
+
+class MossFormerEncoder(nn.Module):
+ """Convolutional Encoder Layer.
+
+ Arguments
+ ---------
+ kernel_size : int
+ Length of filters.
+ in_channels : int
+ Number of input channels.
+ out_channels : int
+ Number of output channels.
+
+ Example
+ -------
+ >>> x = torch.randn(2, 1000)
+ >>> encoder = Encoder(kernel_size=4, out_channels=64)
+ >>> h = encoder(x)
+ >>> h.shape
+ torch.Size([2, 64, 499])
+ """
+
+ def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
+ super(MossFormerEncoder, self).__init__()
+ self.conv1d = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=kernel_size // 2,
+ groups=1,
+ bias=False,
+ )
+ self.in_channels = in_channels
+
+ def forward(self, x):
+ """Return the encoded output.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Input tensor with dimensionality [B, L].
+ Return
+ ------
+ x : torch.Tensor
+ Encoded tensor with dimensionality [B, N, T_out].
+
+ where B = Batchsize
+ L = Number of timepoints
+ N = Number of filters
+ T_out = Number of timepoints at the output of the encoder
+ """
+ # B x L -> B x 1 x L
+ if self.in_channels == 1:
+ x = torch.unsqueeze(x, dim=1)
+ # B x 1 x L -> B x N x T_out
+ x = self.conv1d(x)
+ x = F.relu(x)
+
+ return x
+
+class MossFormerM(nn.Module):
+ """This class implements the transformer encoder.
+
+ Arguments
+ ---------
+ num_blocks : int
+ Number of mossformer blocks to include.
+ d_model : int
+ The dimension of the input embedding.
+ attn_dropout : float
+ Dropout for the self-attention (Optional).
+ group_size: int
+ the chunk size
+ query_key_dim: int
+ the attention vector dimension
+ expansion_factor: int
+ the expansion factor for the linear projection in conv module
+ causal: bool
+ true for causal / false for non causal
+
+ Example
+ -------
+ >>> import torch
+ >>> x = torch.rand((8, 60, 512))
+ >>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
+ >>> output, _ = net(x)
+ >>> output.shape
+ torch.Size([8, 60, 512])
+ """
+ def __init__(
+ self,
+ num_blocks,
+ d_model=None,
+ causal=False,
+ group_size = 256,
+ query_key_dim = 128,
+ expansion_factor = 4.,
+ attn_dropout = 0.1
+ ):
+ super().__init__()
+
+ self.mossformerM = MossformerBlock(
+ dim=d_model,
+ depth=num_blocks,
+ group_size=group_size,
+ query_key_dim=query_key_dim,
+ expansion_factor=expansion_factor,
+ causal=causal,
+ attn_dropout=attn_dropout
+ )
+ self.norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(
+ self,
+ src,
+ ):
+ """
+ Arguments
+ ----------
+ src : torch.Tensor
+ Tensor shape [B, L, N],
+ where, B = Batchsize,
+ L = time points
+ N = number of filters
+ The sequence to the encoder layer (required).
+ src_mask : tensor
+ The mask for the src sequence (optional).
+ src_key_padding_mask : tensor
+ The mask for the src keys per batch (optional).
+ """
+ output = self.mossformerM(src)
+ output = self.norm(output)
+
+ return output
+
+
+class Computation_Block(nn.Module):
+ """Computation block for dual-path processing.
+
+ Arguments
+ ---------
+ out_channels : int
+ Dimensionality of inter/intra model.
+ norm : str
+ Normalization type.
+ skip_around_intra : bool
+ Skip connection around the intra layer.
+
+ Example
+ ---------
+ >>> comp_block = Computation_Block(64)
+ >>> x = torch.randn(10, 64, 100)
+ >>> x = comp_block(x)
+ >>> x.shape
+ torch.Size([10, 64, 100])
+ """
+
+ def __init__(
+ self,
+ num_blocks,
+ out_channels,
+ norm="ln",
+ skip_around_intra=True,
+ ):
+ super(Computation_Block, self).__init__()
+
+ ##MossFormer2M: MossFormer with recurrence
+ #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
+ ##MossFormerM: the orignal MossFormer
+ self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
+ self.skip_around_intra = skip_around_intra
+
+ # Norm
+ self.norm = norm
+ if norm is not None:
+ self.intra_norm = select_norm(norm, out_channels, 3)
+
+ def forward(self, x):
+ """Returns the output tensor.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Input tensor of dimension [B, N, S].
+
+
+ Return
+ ---------
+ out: torch.Tensor
+ Output tensor of dimension [B, N, S].
+ where, B = Batchsize,
+ N = number of filters
+ S = sequence time index
+ """
+ B, N, S = x.shape
+ # intra RNN
+ # [B, S, N]
+ intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
+
+ intra = self.intra_mdl(intra)
+
+ # [B, N, S]
+ intra = intra.permute(0, 2, 1).contiguous()
+ if self.norm is not None:
+ intra = self.intra_norm(intra)
+
+ # [B, N, S]
+ if self.skip_around_intra:
+ intra = intra + x
+
+ out = intra
+ return out
+
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 374eba4..1995bbe 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -9,6 +9,7 @@
import math
import torch
import torch.nn.functional as F
+from torch import einsum
def _pre_hook(
state_dict,
@@ -510,3 +511,19 @@
pos_enc = self.dropout(pos_enc)
return pos_enc
+
+
+class ScaledSinuEmbedding(torch.nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = torch.nn.Parameter(torch.ones(1,))
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x):
+ n, device = x.shape[1], x.device
+ t = torch.arange(n, device = device).type_as(self.inv_freq)
+ sinu = einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+ return emb * self.scale
+
diff --git a/funasr/modules/layer_norm.py b/funasr/modules/layer_norm.py
index 6e934e6..8683230 100644
--- a/funasr/modules/layer_norm.py
+++ b/funasr/modules/layer_norm.py
@@ -7,6 +7,7 @@
"""Layer normalization module."""
import torch
+import torch.nn as nn
class LayerNorm(torch.nn.LayerNorm):
@@ -40,3 +41,137 @@
.forward(x.transpose(self.dim, -1))
.transpose(self.dim, -1)
)
+
+
+class GlobalLayerNorm(nn.Module):
+ """Calculate Global Layer Normalization.
+
+ Arguments
+ ---------
+ dim : (int or list or torch.Size)
+ Input shape from an expected input of size.
+ eps : float
+ A value added to the denominator for numerical stability.
+ elementwise_affine : bool
+ A boolean value that when set to True,
+ this module has learnable per-element affine parameters
+ initialized to ones (for weights) and zeros (for biases).
+
+ Example
+ -------
+ >>> x = torch.randn(5, 10, 20)
+ >>> GLN = GlobalLayerNorm(10, 3)
+ >>> x_norm = GLN(x)
+ """
+
+ def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
+ super(GlobalLayerNorm, self).__init__()
+ self.dim = dim
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+
+ if self.elementwise_affine:
+ if shape == 3:
+ self.weight = nn.Parameter(torch.ones(self.dim, 1))
+ self.bias = nn.Parameter(torch.zeros(self.dim, 1))
+ if shape == 4:
+ self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
+ self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
+ else:
+ self.register_parameter("weight", None)
+ self.register_parameter("bias", None)
+
+ def forward(self, x):
+ """Returns the normalized tensor.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Tensor of size [N, C, K, S] or [N, C, L].
+ """
+ # x = N x C x K x S or N x C x L
+ # N x 1 x 1
+ # cln: mean,var N x 1 x K x S
+ # gln: mean,var N x 1 x 1
+ if x.dim() == 3:
+ mean = torch.mean(x, (1, 2), keepdim=True)
+ var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
+ if self.elementwise_affine:
+ x = (
+ self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ + self.bias
+ )
+ else:
+ x = (x - mean) / torch.sqrt(var + self.eps)
+
+ if x.dim() == 4:
+ mean = torch.mean(x, (1, 2, 3), keepdim=True)
+ var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
+ if self.elementwise_affine:
+ x = (
+ self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ + self.bias
+ )
+ else:
+ x = (x - mean) / torch.sqrt(var + self.eps)
+ return x
+
+
+class CumulativeLayerNorm(nn.LayerNorm):
+ """Calculate Cumulative Layer Normalization.
+
+ Arguments
+ ---------
+ dim : int
+ Dimension that you want to normalize.
+ elementwise_affine : True
+ Learnable per-element affine parameters.
+
+ Example
+ -------
+ >>> x = torch.randn(5, 10, 20)
+ >>> CLN = CumulativeLayerNorm(10)
+ >>> x_norm = CLN(x)
+ """
+
+ def __init__(self, dim, elementwise_affine=True):
+ super(CumulativeLayerNorm, self).__init__(
+ dim, elementwise_affine=elementwise_affine, eps=1e-8
+ )
+
+ def forward(self, x):
+ """Returns the normalized tensor.
+
+ Arguments
+ ---------
+ x : torch.Tensor
+ Tensor size [N, C, K, S] or [N, C, L]
+ """
+ # x: N x C x K x S or N x C x L
+ # N x K x S x C
+ if x.dim() == 4:
+ x = x.permute(0, 2, 3, 1).contiguous()
+ # N x K x S x C == only channel norm
+ x = super().forward(x)
+ # N x C x K x S
+ x = x.permute(0, 3, 1, 2).contiguous()
+ if x.dim() == 3:
+ x = torch.transpose(x, 1, 2)
+ # N x L x C == only channel norm
+ x = super().forward(x)
+ # N x C x L
+ x = torch.transpose(x, 1, 2)
+ return x
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps = 1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+ return x / norm.clamp(min = self.eps) * self.g
+
diff --git a/funasr/modules/mossformer.py b/funasr/modules/mossformer.py
new file mode 100644
index 0000000..f1e8e28
--- /dev/null
+++ b/funasr/modules/mossformer.py
@@ -0,0 +1,307 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+
+
+def identity(t, *args, **kwargs):
+ return t
+
+def append_dims(x, num_dims):
+ if num_dims <= 0:
+ return x
+ return x.view(*x.shape, *((1,) * num_dims))
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+def padding_to_multiple_of(n, mult):
+ remainder = n % mult
+ if remainder == 0:
+ return 0
+ return mult - remainder
+
+
+class Transpose(nn.Module):
+ """ Wrapper class of torch.transpose() for Sequential module. """
+ def __init__(self, shape: tuple):
+ super(Transpose, self).__init__()
+ self.shape = shape
+
+ def forward(self, x):
+ return x.transpose(*self.shape)
+
+
+class DepthwiseConv1d(nn.Module):
+ """
+ When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+ this operation is termed in literature as depthwise convolution.
+ Args:
+ in_channels (int): Number of channels in the input
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+ bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+ Inputs: inputs
+ - **inputs** (batch, in_channels, time): Tensor containing input vector
+ Returns: outputs
+ - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = False,
+ ) -> None:
+ super(DepthwiseConv1d, self).__init__()
+ assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ groups=in_channels,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ )
+
+ def forward(self, inputs):
+ return self.conv(inputs)
+
+
+class ConvModule(nn.Module):
+ """
+ Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
+ This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
+ to aid training deep models.
+ Args:
+ in_channels (int): Number of channels in the input
+ kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
+ dropout_p (float, optional): probability of dropout
+ Inputs: inputs
+ inputs (batch, time, dim): Tensor contains input sequences
+ Outputs: outputs
+ outputs (batch, time, dim): Tensor produces by conformer convolution module.
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ kernel_size: int = 17,
+ expansion_factor: int = 2,
+ dropout_p: float = 0.1,
+ ) -> None:
+ super(ConvModule, self).__init__()
+ assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
+ assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
+
+ self.sequential = nn.Sequential(
+ Transpose(shape=(1, 2)),
+ DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
+ )
+
+ def forward(self, inputs):
+ return inputs + self.sequential(inputs).transpose(1, 2)
+
+
+class OffsetScale(nn.Module):
+ def __init__(self, dim, heads = 1):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
+ self.beta = nn.Parameter(torch.zeros(heads, dim))
+ nn.init.normal_(self.gamma, std = 0.02)
+
+ def forward(self, x):
+ out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
+ return out.unbind(dim = -2)
+
+
+class FFConvM(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_out,
+ norm_klass = nn.LayerNorm,
+ dropout = 0.1
+ ):
+ super().__init__()
+ self.mdl = nn.Sequential(
+ norm_klass(dim_in),
+ nn.Linear(dim_in, dim_out),
+ nn.SiLU(),
+ ConvModule(dim_out),
+ nn.Dropout(dropout)
+ )
+ def forward(
+ self,
+ x,
+ ):
+ output = self.mdl(x)
+ return output
+
+
+class FLASH_ShareA_FFConvM(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ group_size = 256,
+ query_key_dim = 128,
+ expansion_factor = 1.,
+ causal = False,
+ dropout = 0.1,
+ rotary_pos_emb = None,
+ norm_klass = nn.LayerNorm,
+ shift_tokens = True
+ ):
+ super().__init__()
+ hidden_dim = int(dim * expansion_factor)
+ self.group_size = group_size
+ self.causal = causal
+ self.shift_tokens = shift_tokens
+
+ # positional embeddings
+ self.rotary_pos_emb = rotary_pos_emb
+ # norm
+ self.dropout = nn.Dropout(dropout)
+ # projections
+
+ self.to_hidden = FFConvM(
+ dim_in = dim,
+ dim_out = hidden_dim,
+ norm_klass = norm_klass,
+ dropout = dropout,
+ )
+ self.to_qk = FFConvM(
+ dim_in = dim,
+ dim_out = query_key_dim,
+ norm_klass = norm_klass,
+ dropout = dropout,
+ )
+
+ self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
+
+ self.to_out = FFConvM(
+ dim_in = dim*2,
+ dim_out = dim,
+ norm_klass = norm_klass,
+ dropout = dropout,
+ )
+
+ self.gateActivate=nn.Sigmoid()
+
+ def forward(
+ self,
+ x,
+ *,
+ mask = None
+ ):
+
+ """
+ b - batch
+ n - sequence length (within groups)
+ g - group dimension
+ d - feature dimension (keys)
+ e - feature dimension (values)
+ i - sequence dimension (source)
+ j - sequence dimension (target)
+ """
+
+ normed_x = x
+
+ # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
+ residual = x
+
+ if self.shift_tokens:
+ x_shift, x_pass = normed_x.chunk(2, dim = -1)
+ x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
+ normed_x = torch.cat((x_shift, x_pass), dim = -1)
+
+ # initial projections
+
+ v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
+ qk = self.to_qk(normed_x)
+
+ # offset and scale
+ quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
+ att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
+ out = (att_u*v ) * self.gateActivate(att_v*u)
+ x = x + self.to_out(out)
+ return x
+
+ def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
+ b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
+
+ if exists(mask):
+ lin_mask = rearrange(mask, '... -> ... 1')
+ lin_k = lin_k.masked_fill(~lin_mask, 0.)
+
+ # rotate queries and keys
+
+ if exists(self.rotary_pos_emb):
+ quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
+
+ # padding for groups
+
+ padding = padding_to_multiple_of(n, g)
+
+ if padding > 0:
+ quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+ mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
+ mask = F.pad(mask, (0, padding), value = False)
+
+ # group along sequence
+
+ quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
+
+ if exists(mask):
+ mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
+
+ # calculate quadratic attention output
+
+ sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
+
+ attn = F.relu(sim) ** 2
+ attn = self.dropout(attn)
+
+ if exists(mask):
+ attn = attn.masked_fill(~mask, 0.)
+
+ if self.causal:
+ causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
+ attn = attn.masked_fill(causal_mask, 0.)
+
+ quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
+ quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
+
+ # calculate linear attention output
+
+ if self.causal:
+ lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
+ # exclusive cumulative sum along group dimension
+ lin_kv = lin_kv.cumsum(dim = 1)
+ lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
+ lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
+
+ lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
+ # exclusive cumulative sum along group dimension
+ lin_ku = lin_ku.cumsum(dim = 1)
+ lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
+ lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
+ else:
+ lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
+ lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
+
+ lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
+ lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
+
+ # fold back groups into full sequence, and excise out padding
+ return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
+
--
Gitblit v1.9.1