| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import logging |
| | | from pathlib import Path |
| | | from typing import List |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | |
| | | |
class SpeechSeparator:
    """Speech-separation inference wrapper.

    Builds a separation model ("ss" task) from a config/checkpoint pair and
    exposes ``__call__`` to split a mixture waveform into per-speaker
    waveforms.

    Examples:
        >>> import soundfile
        >>> speech_separator = SpeechSeparator("ss_config.yml", "ss.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separated_wavs = speech_separator(audio)

    """

    def __init__(
        self,
        ss_infer_config: Union[Path, str] = None,
        ss_model_file: Union[Path, str] = None,
        device: str = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        **kwargs,
    ):
        """Build and freeze the separation model.

        Args:
            ss_infer_config: Path to the inference YAML configuration.
            ss_model_file: Path to the model checkpoint.
            device: Torch device string ("cpu" or "cuda").
            batch_size: Stored for reference; ``decode`` processes one
                utterance at a time.
            dtype: Torch dtype name for the model weights.
            **kwargs: Ignored; accepted for interface compatibility.
        """

        # 1. Build ss model
        ss_model, ss_infer_args = build_model_from_file(
            ss_infer_config, ss_model_file, None, device, task_name="ss"
        )

        logging.info("ss_model: {}".format(ss_model))
        logging.info("ss_infer_args: {}".format(ss_infer_args))

        # Cast weights to the requested dtype and switch to eval mode.
        ss_model.to(dtype=getattr(torch, dtype)).eval()

        self.ss_model = ss_model
        self.ss_infer_args = ss_infer_args
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size

    def decode(self, model, args, inputs, nsamples):
        """Run separation on one (possibly long) utterance.

        Args:
            model: The separation model; called as ``model(tensor)`` and
                expected to return one tensor per speaker.
            args: Inference args providing ``sample_rate``, ``decode_window``
                (seconds), ``one_time_decode_length`` and ``num_spks``.
            inputs: 2-D array of shape [batch, time].
                NOTE(review): treated as a numpy array below
                (``np.concatenate``) even though ``__call__`` also accepts
                torch tensors — confirm the loader always yields numpy here.
            nsamples: Number of samples to keep in each output (the original,
                un-padded length).

        Returns:
            List of ``args.num_spks`` 1-D numpy arrays, peak-normalized and
            truncated to ``nsamples``.
        """
        decode_do_segment = False
        with torch.no_grad():
            out = []
            window = args.sample_rate * args.decode_window # decoding window length
            stride = int(window*0.75) # decoding stride if segmentation is used
            b, t = inputs.shape
            if t > window * args.one_time_decode_length:
                decode_do_segment = True # set segment decoding to true for very long sequence

            # Zero-pad so at least one full window (or an integral number of
            # strides) fits.
            if t < window:
                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
            elif t < window + stride:
                padding = window + stride - t
                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
            else:
                if (t - window) % stride != 0:
                    # NOTE(review): this pads by window + (t-window) % stride,
                    # which does NOT make (t - window) divisible by stride
                    # (that would be stride - (t-window) % stride). It always
                    # over-pads enough for the sliding windows below to cover
                    # the original samples, so results are unaffected —
                    # confirm before "simplifying" this expression.
                    padding = t - (t-window)//stride * stride
                    inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
            inputs = torch.from_numpy(np.float32(inputs))
            inputs = to_device(inputs, device=self.device)
            b, t = inputs.shape
            if decode_do_segment:
                # Overlap-add style stitching: each window's first/last
                # give_up_length samples are discarded (except at the edges)
                # to avoid boundary artifacts.
                outputs = np.zeros((args.num_spks, t))
                give_up_length = (window - stride)//2
                current_idx = 0
                while current_idx + window <= t:
                    tmp_input = inputs[:, current_idx:current_idx+window]
                    tmp_out_list = model(tmp_input,)
                    for spk in range(args.num_spks):
                        tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
                        if current_idx == 0:
                            # First window: keep its left edge.
                            outputs[spk, current_idx:current_idx+window-give_up_length] = \
                                tmp_out_list[spk][:-give_up_length]
                        else:
                            outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
                                tmp_out_list[spk][give_up_length:-give_up_length]
                    current_idx += stride
                for spk in range(args.num_spks):
                    out.append(outputs[spk, :])
            else:
                # Short input: decode the whole utterance in one pass.
                out_list = model(inputs)
                for spk in range(args.num_spks):
                    out.append(out_list[spk][0, :].cpu().numpy())

            # Peak-normalize all speakers by the global maximum so relative
            # levels between speakers are preserved.
            # NOTE(review): an all-zero output would make max_abs == 0 and
            # divide by zero — presumably never happens with real audio.
            max_abs = 0
            for spk in range(args.num_spks):
                if max_abs < max(abs(out[spk])):
                    max_abs = max(abs(out[spk]))
            for spk in range(args.num_spks):
                out[spk] = out[spk][:nsamples]
                out[spk] = out[spk]/max_abs

        return out

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
    ) -> List[torch.Tensor]:
        """Separate a mixture into per-speaker signals.

        Args:
            speech: Input mixture, shape [batch, time].
            speech_lengths: Original sample count; passed to ``decode`` as
                ``nsamples`` and used to truncate the padded outputs.
        Returns:
            speech list: list of separated speech arrays, one per speaker.

        """

        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)

        return out
| | | |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from typing import Optional |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import soundfile as sf |
| | | from funasr.build_utils.build_streaming_iterator import build_streaming_iterator |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.bin.ss_infer import SpeechSeparator |
| | | |
| | | |
def inference_ss(
    batch_size: int,
    ngpu: int,
    log_level: Union[int, str],
    ss_infer_config: Optional[str],
    ss_model_file: Optional[str],
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    num_workers: int = 1,
    num_spks: int = 2,
    sample_rate: int = 8000,
    param_dict: dict = None,
    **kwargs,
):
    """Build a speech-separation inference pipeline.

    Args:
        batch_size: Must be 1; batch decoding is not implemented.
        ngpu: Number of GPUs; 0 (or no CUDA) forces CPU decoding.
        log_level: Verbosity for the root logger.
        ss_infer_config: Path to the inference YAML configuration.
        ss_model_file: Path to the model checkpoint.
        output_dir: Directory where separated wavs are written; may be
            overridden per call via ``output_dir_v2``.
        dtype: Torch dtype name for the model weights.
        seed: Random seed set before model construction.
        num_workers: DataLoader worker count.
        num_spks: Number of speakers, i.e. output wav files per utterance.
        sample_rate: Sample rate (Hz) used when writing output wavs.
        param_dict: Unused; kept for interface compatibility.
        **kwargs: Extra options; ``ncpu`` limits torch CPU threads.

    Returns:
        A callable ``_forward(data_path_and_name_and_type, ...)`` that runs
        separation and writes ``<key>_s<i>.wav`` files to the output dir.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech separator
    speech_separator_kwargs = dict(
        ss_infer_config=ss_infer_config,
        ss_model_file=ss_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
    speech_separator = SpeechSeparator(**speech_separator_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None
    ):
        """Run separation over a dataset (or a raw waveform) and write wavs."""
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="ss",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        # 4. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is None:
            # Fail fast with a clear message instead of the TypeError that
            # os.path.exists(None) used to raise.
            raise ValueError("output_dir must be set (or pass output_dir_v2)")
        # os.makedirs replaces the previous os.system('mkdir -p ' + path)
        # shell call, which was non-portable and unsafe for paths containing
        # shell metacharacters.
        os.makedirs(output_path, exist_ok=True)

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # do speech separation
            logging.info('decoding: {}'.format(keys[0]))
            ss_results = speech_separator(**batch)

            # One output file per speaker: <key>_s1.wav, <key>_s2.wav, ...
            for spk in range(num_spks):
                sf.write(os.path.join(output_path, keys[0].replace('.wav', '_s'+str(spk+1)+'.wav')), ss_results[spk], sample_rate)
        torch.cuda.empty_cache()
        return ss_results

    return _forward
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch to the pipeline builder for the given decoding mode.

    Args:
        mode: Decoding mode; only "mossformer" is supported.
        **kwargs: Forwarded to the pipeline builder.

    Returns:
        A callable inference pipeline.

    Raises:
        ValueError: If ``mode`` is unknown. Previously this logged at info
            level and returned None, which made ``main`` fail later with an
            opaque "'NoneType' object is not callable" TypeError.
    """
    if mode == "mossformer":
        return inference_ss(**kwargs)
    logging.error("Unknown decoding mode: {}".format(mode))
    raise ValueError("Unknown decoding mode: {}".format(mode))
| | | |
| | | |
def get_parser():
    """Build the command-line argument parser for speech-separation decoding."""
    parser = config_argparse.ArgumentParser(
        description="Speech Separator Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=1,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="2",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--ss_infer_config",
        type=str,
        help="SS infer configuration",
    )
    group.add_argument(
        "--ss_model_file",
        type=str,
        help="SS model parameter file",
    )
    group.add_argument(
        "--ss_train_config",
        type=str,
        help="SS training configuration",
    )

    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    parser.add_argument(
        '--num-spks', dest='num_spks', type=int, default=2,
        help='the number of speakers to separate')

    parser.add_argument(
        '--one-time-decode-length', dest='one_time_decode_length', type=int,
        default=60, help='the max length (second) for one-time decoding')

    parser.add_argument(
        '--decode-window', dest='decode_window', type=int,
        default=1, help='segmental decoding window length (second)')

    # Fix: the default was the string '8000'. argparse happens to run string
    # defaults through `type`, but an int default is clearer and keeps the
    # declared type consistent with the parsed value.
    parser.add_argument(
        '--sample-rate', dest='sample_rate', type=int, default=8000,
        help='sample rate (Hz) of the output wav files')
    return parser
| | | |
| | | |
def main(cmd=None):
    """CLI entry point: parse arguments, configure logging/GPUs, run decoding.

    Args:
        cmd: Optional argv list (for testing); defaults to sys.argv[1:].
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="mossformer",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    # "config" is injected by config_argparse and is not an inference option.
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        # NOTE(review): assumes output_dir ends in ".<jobid>" (e.g. "out.1")
        # so each parallel job can select its GPU from gpuid_list — confirm
        # callers always follow this naming convention, otherwise int() raises.
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
| | | |
| | | |
# Script entry point.
if __name__ == "__main__":
    main()
| | | |
| | |
| | | from funasr.build_utils.build_punc_model import build_punc_model |
| | | from funasr.build_utils.build_sv_model import build_sv_model |
| | | from funasr.build_utils.build_vad_model import build_vad_model |
| | | from funasr.build_utils.build_ss_model import build_ss_model |
| | | |
| | | |
| | | def build_model(args): |
| | |
| | | model = build_diar_model(args) |
| | | elif args.task_name == "sv": |
| | | model = build_sv_model(args) |
| | | elif args.task_name == "ss": |
| | | model = build_ss_model(args) |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| | |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | |
def load_checkpoint(checkpoint_path, use_cuda=1):
    """Load a checkpoint dict from disk.

    When ``use_cuda`` is falsy, every storage is remapped onto CPU so that
    GPU-saved checkpoints load on CPU-only hosts.
    """
    if not use_cuda:
        return torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
| | | |
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
    """Restore separation-model weights from a checkpoint for evaluation."""
    state = load_checkpoint(checkpoint_path, use_cuda)['model']
    # strict=False: tolerate missing/unexpected keys in the checkpoint.
    model.load_state_dict(state, strict=False)
| | | |
| | | def build_model_from_file( |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == 'ss': |
| | | reload_ss_for_eval(model, model_file, use_cuda=True) |
| | | logging.info("model is loaded from path: {}".format(model_file)) |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | if task_name == "vad": |
| New file |
| | |
| | | from funasr.models.e2e_ss import MossFormer |
| | | |
def build_ss_model(args):
    """Instantiate the MossFormer separation model from parsed config args."""
    model_conf = {
        "in_channels": args.encoder_embedding_dim,
        "out_channels": args.mossformer_sequence_dim,
        "num_blocks": args.num_mossformer_layer,
        "kernel_size": args.encoder_kernel_size,
        "norm": args.norm,
        "num_spks": args.num_spks,
        "skip_around_intra": args.skip_around_intra,
        "use_global_pos_enc": args.use_global_pos_enc,
        "max_length": args.max_length,
    }
    return MossFormer(**model_conf)
| New file |
| | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | |
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.


    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(MossFormerDecoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L] (or [B, L], which is
            unsqueezed to a single channel).
            where, B = Batchsize,
                   N = number of filters
                   L = time points

        Raises
        ------
        RuntimeError
            If the input is not a 2D or 3D tensor.
        """

        if x.dim() not in [2, 3]:
            # Bug fix: `self.__name__` does not exist on nn.Module, so the
            # error path itself raised AttributeError; the message also
            # claimed 3/4D while the check accepts 2/3D input.
            raise RuntimeError(
                "{} accepts 2/3D tensor as input".format(type(self).__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        # Drop the singleton channel axis; keep batch axis for 1-sample output.
        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x
| | | |
| New file |
| | |
| | | import math |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | import copy |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet |
| | | from funasr.models.decoder.mossformer_decoder import MossFormerDecoder |
| | | |
| | | |
class MossFormer(FunASRModel):
    """MossFormer speech-separation model.

    Separates a mixture waveform into ``num_spks`` single-speaker waveforms
    via: conv encoder -> mask-estimation network -> per-speaker masking ->
    transposed-conv decoder.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.
    kernel_size: int
        Encoder and decoder kernel size
    """

    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super().__init__()
        self.num_spks = num_spks
        # Waveform [B, T] -> latent features [B, N, S].
        self.enc = MossFormerEncoder(
            kernel_size=kernel_size, out_channels=in_channels, in_channels=1
        )
        # Latent features -> one multiplicative mask per speaker.
        self.mask_net = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )
        # Masked features -> time-domain waveform.
        self.dec = MossFormerDecoder(
            in_channels=out_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False,
        )

    def forward(self, input):
        """Separate ``input`` [B, T] into a list of ``num_spks`` tensors [B, T]."""
        mix_w = self.enc(input)
        est_mask = self.mask_net(mix_w)
        # Broadcast the encoded mixture across speakers and apply the masks:
        # [spks, B, N, S].
        sep_h = torch.stack([mix_w] * self.num_spks) * est_mask

        # Decode each speaker stream back to the time domain: [B, T_est, spks].
        est_source = torch.stack(
            [self.dec(sep_h[i]) for i in range(self.num_spks)], dim=-1
        )

        # Match the decoder output length to the input length.
        T_origin = input.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return [est_source[:, :, spk] for spk in range(self.num_spks)]
| New file |
| | |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | from rotary_embedding_torch import RotaryEmbedding |
| | | from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm |
| | | from funasr.modules.embedding import ScaledSinuEmbedding |
| | | from funasr.modules.mossformer import FLASH_ShareA_FFConvM |
| | | |
| | | |
def select_norm(norm, dim, shape):
    """Map a normalization name to a module instance.

    "gln" -> GlobalLayerNorm, "cln" -> CumulativeLayerNorm,
    "ln" -> GroupNorm(1, dim) (layer norm over channels);
    any other value falls back to BatchNorm1d.
    """
    builders = {
        "gln": lambda: GlobalLayerNorm(dim, shape, elementwise_affine=True),
        "cln": lambda: CumulativeLayerNorm(dim, elementwise_affine=True),
        "ln": lambda: nn.GroupNorm(1, dim, eps=1e-8),
    }
    make = builders.get(norm, lambda: nn.BatchNorm1d(dim))
    return make()
| | | |
| | | |
class MossformerBlock(nn.Module):
    """A stack of FLASH_ShareA_FFConvM layers sharing one rotary embedding."""

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        attn_dropout = 0.1,
        norm_type = 'scalenorm',
        shift_tokens = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        norm_klass = ScaleNorm if norm_type == 'scalenorm' else nn.LayerNorm

        self.group_size = group_size

        # Partial rotary embeddings capped at 32 dims (from Wang et al - GPT-J).
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        self.layers = nn.ModuleList(
            [
                FLASH_ShareA_FFConvM(
                    dim=dim,
                    group_size=group_size,
                    query_key_dim=query_key_dim,
                    expansion_factor=expansion_factor,
                    causal=causal,
                    dropout=attn_dropout,
                    rotary_pos_emb=rotary_pos_emb,
                    norm_klass=norm_klass,
                    shift_tokens=shift_tokens,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x, *, mask=None):
        """Apply each FLASH layer in sequence to x."""
        for flash_layer in self.layers:
            x = flash_layer(x, mask=mask)
        return x
| | | |
| | | |
class MossFormer_MaskNet(nn.Module):
    """The MossFormer module for computing output masks.

    Takes encoder features [B, N, S] and produces one multiplicative mask per
    speaker, returned as [spks, B, N, S].

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.

    Example
    ---------
    >>> mossformer_block = MossFormerM(1, 64, 8)
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        # Input normalization over the encoder channels.
        self.norm = select_norm(norm, in_channels, 3)
        # 1x1 conv projecting encoder channels into the block dimension.
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        # Stack of MossFormer computation blocks.
        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        # Expands channels so each speaker gets its own feature slice.
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        # Projects back to the encoder channel count for masking.
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        # ReLU keeps the final masks non-negative.
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, S]
            where, spks = Number of speakers
               B = Batchsize,
               N = number of filters
               S = the number of time frames
        """

        # before each line we indicate the shape after executing the line

        # [B, N, L]
        x = self.norm(x)

        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #    x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            # pos_enc consumes [B, L, N]; NOTE(review): emb appears to be
            # [L, N], so transpose(0, -1) yields [N, L] which broadcasts over
            # the batch axis of base [B, N, L] — confirm against
            # ScaledSinuEmbedding before changing.
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb


        # [B, N, S]
        #for i in range(self.num_modules):
        #    x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)

        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape

        # [B*spks, N, S] — fold the speaker axis into the batch axis so the
        # gated output convs run once over all speakers.
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, S]
        x = self.conv1_decoder(x)

        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N, S]
        x = x.transpose(0, 1)

        return x
| | | |
| | | |
class MossFormerEncoder(nn.Module):
    """Convolutional encoder: waveform [B, L] -> features [B, N, T_out].

    A single strided Conv1d (stride = kernel_size // 2) followed by ReLU.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = MossFormerEncoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super(MossFormerEncoder, self).__init__()
        self.in_channels = in_channels
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )

    def forward(self, x):
        """Encode the input waveform.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L] (single-channel) or
            [B, C, L] when in_channels > 1.

        Returns
        -------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out], where
            B = batch size, N = number of filters, T_out = output frames.
        """
        if self.in_channels == 1:
            # B x L -> B x 1 x L: add the channel axis expected by Conv1d.
            x = x.unsqueeze(dim=1)
        return F.relu(self.conv1d(x))
| | | |
class MossFormerM(nn.Module):
    """The MossFormer sequence encoder: a MossformerBlock stack + LayerNorm.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks to include.
    d_model : int
        The dimension of the input embedding.
    attn_dropout : float
        Dropout for the self-attention (Optional).
    group_size: int
        the chunk size
    query_key_dim: int
        the attention vector dimension
    expansion_factor: int
        the expansion factor for the linear projection in conv module
    causal: bool
        true for causal / false for non causal

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        attn_dropout=0.1,
    ):
        super().__init__()
        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        # Final normalization over the model dimension.
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """Encode ``src`` [B, L, N] through the block stack, then LayerNorm."""
        return self.norm(self.mossformerM(src))
| | | |
| | | |
class Computation_Block(nn.Module):
    """Computation block for dual-path processing.

    Wraps an intra-chunk MossFormerM model with optional normalization and an
    optional residual (skip) connection.

    Arguments
    ---------
    out_channels : int
        Dimensionality of inter/intra model.
    norm : str
        Normalization type (None disables the post-norm).
    skip_around_intra : bool
        Skip connection around the intra layer.

    Example
    ---------
    >>> comp_block = Computation_Block(64)
    >>> x = torch.randn(10, 64, 100)
    >>> x = comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super(Computation_Block, self).__init__()

        ##MossFormer2M: MossFormer with recurrence
        #self.intra_mdl = MossFormer2M(num_blocks=num_blocks, d_model=out_channels)
        ##MossFormerM: the orignal MossFormer
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

        # Norm
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Process a [B, N, S] tensor and return the same shape.

        B = batch size, N = number of filters, S = sequence time index.
        """
        residual = x

        # The sequence model expects [B, S, N].
        intra = self.intra_mdl(x.permute(0, 2, 1).contiguous())

        # Back to [B, N, S].
        intra = intra.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            intra = self.intra_norm(intra)

        # Optional residual connection around the intra model.
        if self.skip_around_intra:
            intra = intra + residual

        return intra
| | | |
| | |
| | | import math |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import einsum |
| | | |
| | | def _pre_hook( |
| | | state_dict, |
| | |
| | | pos_enc = self.dropout(pos_enc) |
| | | |
| | | return pos_enc |
| | | |
| | | |
class ScaledSinuEmbedding(torch.nn.Module):
    """Sinusoidal positional embedding with a single learnable scale.

    forward(x) returns a [L, dim] table for sequence length L = x.shape[1];
    the first dim/2 columns are sines, the rest cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1,))
        # Standard transformer inverse frequencies over even indices.
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product position x frequency (equivalent to einsum 'i,j->ij').
        angles = positions[:, None] * self.inv_freq[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return emb * self.scale
| | | |
| | |
| | | """Layer normalization module.""" |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | |
| | | class LayerNorm(torch.nn.LayerNorm): |
| | |
| | | .forward(x.transpose(self.dim, -1)) |
| | | .transpose(self.dim, -1) |
| | | ) |
| | | |
| | | |
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN).

    Normalizes over every dimension except the batch axis.

    Arguments
    ---------
    dim : (int or list or torch.Size)
        Channel dimension; size of the learnable affine parameters.
    shape : int
        Rank of the expected input: 3 for [N, C, L], 4 for [N, C, K, S].
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        If True, apply a learnable per-channel scale (initialized to
        ones) and shift (initialized to zeros).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Trailing singleton axes let the affine parameters broadcast
            # over the time dimension(s).
            affine_shapes = {3: (self.dim, 1), 4: (self.dim, 1, 1)}
            if shape in affine_shapes:
                self.weight = nn.Parameter(torch.ones(affine_shapes[shape]))
                self.bias = nn.Parameter(torch.zeros(affine_shapes[shape]))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the globally normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]; tensors of any
            other rank are returned unchanged.
        """
        if x.dim() not in (3, 4):
            return x
        # gLN statistics: mean/var over all non-batch dimensions,
        # yielding N x 1 x 1 (or N x 1 x 1 x 1).
        stat_dims = tuple(range(1, x.dim()))
        mean = torch.mean(x, stat_dims, keepdim=True)
        var = torch.mean((x - mean) ** 2, stat_dims, keepdim=True)
        denom = torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            return self.weight * (x - mean) / denom + self.bias
        return (x - mean) / denom
| | | |
| | | |
class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-wise Layer Normalization (cLN).

    Moves the channel axis last, applies a standard LayerNorm over it,
    then moves the channel axis back.

    Arguments
    ---------
    dim : int
        Number of channels to normalize over.
    elementwise_affine : True
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the channel-normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]; tensors of any
            other rank are returned unchanged.
        """
        if x.dim() == 4:
            # [N, C, K, S] -> [N, K, S, C]: channels last for LayerNorm.
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            normed = super().forward(channels_last)
            # Back to [N, C, K, S].
            return normed.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            # [N, C, L] -> [N, L, C].
            channels_last = torch.transpose(x, 1, 2)
            normed = super().forward(channels_last)
            # Back to [N, C, L].
            return torch.transpose(normed, 1, 2)
        return x
| | | |
| | | |
class ScaleNorm(nn.Module):
    """Scale normalization: rescale x so its last-dim L2 norm is about
    sqrt(dim), with a single learnable gain ``g``."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        # 1/sqrt(dim), so unit-variance inputs map to unit scaled norm.
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        scaled_norm = x.norm(dim=-1, keepdim=True) * self.scale
        # Clamp protects against division by ~zero norms.
        return x / scaled_norm.clamp(min=self.eps) * self.g
| | | |
| New file |
| | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import nn, einsum |
| | | from einops import rearrange |
| | | |
| | | |
def identity(t, *args, **kwargs):
    """Return *t* unchanged; extra positional/keyword args are accepted
    for interface compatibility and ignored."""
    return t
| | | |
def append_dims(x, num_dims):
    """Return *x* with *num_dims* trailing singleton dimensions appended.

    If num_dims <= 0, the original tensor object is returned unchanged.
    """
    if num_dims <= 0:
        return x
    new_shape = tuple(x.shape) + (1,) * num_dims
    return x.view(*new_shape)
| | | |
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
| | | |
def default(val, d):
    """Return *val* unless it is None, in which case return fallback *d*."""
    if val is None:
        return d
    return val
| | | |
def padding_to_multiple_of(n, mult):
    """Smallest non-negative pad that makes *n* a multiple of *mult*.

    ``(-n) % mult`` is 0 when n already divides evenly, otherwise
    ``mult - (n % mult)`` — identical to the branching form.
    """
    return -n % mult
| | | |
| | | |
class Transpose(nn.Module):
    """nn.Sequential-friendly wrapper around torch.Tensor.transpose."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        # The pair of dimension indices to swap.
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)
| | | |
| | | |
class DepthwiseConv1d(nn.Module):
    """Depthwise 1-D convolution (``groups == in_channels``).

    When out_channels == K * in_channels for a positive integer K, each
    input channel is convolved with its own K filters.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution;
            must be an integer multiple of ``in_channels``
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides
            of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the
            output. Default: False

    Shape:
        input (batch, in_channels, time) -> output (batch, out_channels, time').
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super().__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        # groups == in_channels makes the convolution depthwise.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, inputs):
        """Apply the depthwise convolution to (batch, in_channels, time)."""
        return self.conv(inputs)
| | | |
| | | |
class ConvModule(nn.Module):
    """Residual depthwise-convolution block.

    Transposes to (batch, dim, time), applies a 'SAME'-padded depthwise
    1-D convolution, transposes back, and adds the input as a residual.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Depthwise kernel size; must be odd
            for 'SAME' padding. Default: 17
        expansion_factor (int, optional): Must be 2; kept for interface
            compatibility (unused by the current implementation).
        dropout_p (float, optional): Unused by the current implementation.

    Shape:
        (batch, time, dim) -> (batch, time, dim).
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        same_padding = (kernel_size - 1) // 2
        # Attribute name 'sequential' kept for checkpoint compatibility.
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=same_padding),
        )

    def forward(self, inputs):
        """Return inputs + depthwise_conv(inputs) (residual connection)."""
        conv_out = self.sequential(inputs).transpose(1, 2)
        return inputs + conv_out
| | | |
| | | |
class OffsetScale(nn.Module):
    """Produce *heads* affine views of the input: x * gamma_h + beta_h.

    Used to split one shared projection into several cheap per-head
    variants; returns a tuple of *heads* tensors shaped like x.
    """

    def __init__(self, dim, heads=1):
        super().__init__()
        # Per-head scale (initialized near 1) and offset (zeros).
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        # Broadcast (..., 1, d) * (h, d) + (h, d) -> (..., h, d).
        scaled = x.unsqueeze(-2) * self.gamma + self.beta
        return scaled.unbind(dim=-2)
| | | |
| | | |
class FFConvM(nn.Module):
    """Feed-forward block: norm -> linear -> SiLU -> ConvModule -> dropout.

    Maps (batch, time, dim_in) to (batch, time, dim_out).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        # Attribute name 'mdl' kept for checkpoint compatibility.
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the feed-forward stack to x."""
        return self.mdl(x)
| | | |
| | | |
class FLASH_ShareA_FFConvM(nn.Module):
    """Gated attention unit (FLASH/GAU style) with FFConvM projections.

    Combines quadratic attention computed within fixed-size groups and
    linear attention across groups; both branches share a single
    query/key projection split into four offset-scaled views.
    """

    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 1.,
        causal = False,
        dropout = 0.1,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        """
        Args:
            dim: feature dimension of input and output.
            group_size: chunk length for quadratic (local) attention.
            query_key_dim: width of the shared query/key projection.
            expansion_factor: hidden-width multiplier for the value/gate path.
            causal: if True, restrict attention to past positions.
            dropout: dropout probability used throughout.
            rotary_pos_emb: optional rotary positional embedding module.
            norm_klass: normalization class used inside FFConvM.
            shift_tokens: if True, shift half the channels one step back.
        """
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # attention dropout
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )

        # one shared qk tensor becomes four affine views:
        # quadratic q/k and linear q/k
        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)

        self.to_out = FFConvM(
            dim_in = dim*2,
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
        )

        self.gateActivate=nn.Sigmoid()

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Apply gated attention with a residual connection.

        Einsum letters used below:
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)

        Args:
            x: input of shape [batch, seq_len, dim].
            mask: optional boolean padding mask of shape [batch, seq_len];
                False marks padded positions to exclude from attention.

        Returns:
            Tensor of shape [batch, seq_len, dim].
        """

        normed_x = x

        # do token shift - a great, costless trick from an independent AI
        # researcher in Shenzhen: half the channels see the previous step
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections

        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        # BUG FIX: the caller-supplied mask was previously dropped here,
        # so padded positions silently took part in attention.
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask = mask)
        # gated combination of the two attended streams
        out = (att_u*v ) * self.gateActivate(att_v*u)
        x = x + self.to_out(out)
        return x

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
        """Compute quadratic (grouped) + linear attention for v and u.

        Args:
            x: reference tensor for batch size / sequence length / device.
            quad_q, quad_k: queries/keys for the quadratic branch.
            lin_q, lin_k: queries/keys for the linear branch.
            v, u: value and gate streams.
            mask: optional boolean padding mask [batch, seq_len].

        Returns:
            Tuple (att_v, att_u), each of shape [batch, seq_len, e].
        """
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        if exists(mask):
            # zero out padded keys so they contribute nothing to the
            # linear-attention summaries
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys

        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # padding for groups: make seq_len divisible by group size g

        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))

            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # group along sequence: b (g n) d -> b g n d

        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # calculate quadratic attention output (relu^2 scoring per group)

        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # calculate linear attention output

        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension keeps causality
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim = 1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # fold back groups into full sequence, and excise out padding
        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
| | | |