hnluo
2023-08-10 bce72487636cf84c463381096216e995deb1920d
add mossformer code
4个文件已修改
7个文件已添加
1437 ■■■■■ 已修改文件
funasr/bin/ss_infer.py 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/ss_inference_launch.py 253 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model_from_file.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_ss_model.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/mossformer_decoder.py 53 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_ss.py 95 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/mossformer_encoder.py 417 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/embedding.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/layer_norm.py 135 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/mossformer.py 307 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/ss_infer.py
New file
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
from pathlib import Path
from typing import List
from typing import Union
import numpy as np
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
class SpeechSeparator:
    """Separate a speech mixture into per-speaker waveforms with a separation model.

    Examples:
        >>> import soundfile
        >>> speech_separator = SpeechSeparator("ss_config.yml", "ss.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separated_wavs = speech_separator(audio)
    """

    def __init__(
            self,
            ss_infer_config: Union[Path, str] = None,
            ss_model_file: Union[Path, str] = None,
            device: str = "cpu",
            batch_size: int = 1,
            dtype: str = "float32",
            **kwargs,
    ):
        """Build the separation model and cast it to ``dtype`` in eval mode.

        Args:
            ss_infer_config: model / inference configuration file.
            ss_model_file: model checkpoint file.
            device: device the model is built on; inputs are moved there too.
            batch_size: stored on the instance but not read by ``decode``.
            dtype: name of a torch dtype (e.g. "float32"); used for the model
                and for the decoded inputs.
            **kwargs: ignored.
        """
        # 1. Build ss model
        ss_model, ss_infer_args = build_model_from_file(
            ss_infer_config, ss_model_file, None, device, task_name="ss"
        )
        logging.info("ss_model: {}".format(ss_model))
        logging.info("ss_infer_args: {}".format(ss_infer_args))
        ss_model.to(dtype=getattr(torch, dtype)).eval()
        self.ss_model = ss_model
        self.ss_infer_args = ss_infer_args
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size

    @staticmethod
    def _padding_length(t, window, stride):
        """Return how many zeros to append so the length becomes window + k * stride."""
        if t < window:
            return window - t
        if t < window + stride:
            return window + stride - t
        remainder = (t - window) % stride
        return 0 if remainder == 0 else stride - remainder

    @staticmethod
    def _decode_segments(model, inputs, num_spks, window, stride):
        """Decode ``inputs`` in overlapping windows and stitch the results.

        ``inputs`` must already be padded to ``window + k * stride`` samples.
        Each window's first and last ``give_up`` samples are discarded, except
        the head of the first window and the tail of the last window, so the
        whole padded signal is covered (previously the final tail stayed zero).
        """
        total = inputs.shape[1]
        outputs = np.zeros((num_spks, total))
        give_up = (window - stride) // 2
        last_start = total - window
        for start in range(0, last_start + 1, stride):
            chunk_out = model(inputs[:, start:start + window])
            lo = 0 if start == 0 else give_up
            hi = window if start == last_start else window - give_up
            for spk in range(num_spks):
                outputs[spk, start + lo:start + hi] = chunk_out[spk][0, lo:hi].cpu().numpy()
        return [outputs[spk, :] for spk in range(num_spks)]

    def decode(self, model, args, inputs, nsamples):
        """Separate ``inputs`` into ``args.num_spks`` peak-normalized waveforms.

        Args:
            model: callable taking a [B, T] tensor and returning a sequence of
                per-speaker tensors; only batch item 0 (``[0, ...]``) is used.
            args: namespace with ``sample_rate``, ``decode_window`` and
                ``one_time_decode_length`` (both multiplied into sample counts)
                and ``num_spks``.
            inputs: mixture as a [B, T] or [T] numpy array or tensor.
            nsamples: samples to keep per output (int, 1-element tensor/array),
                or None to keep the full padded length.
        Returns:
            list of ``num_spks`` 1-D numpy arrays, jointly scaled so the largest
            absolute sample is 1 (left unscaled when the output is all zeros).
        """
        # Accept tensors as well as arrays, and a single 1-D waveform.
        if isinstance(inputs, torch.Tensor):
            inputs = inputs.detach().cpu().numpy()
        inputs = np.asarray(inputs)
        if inputs.ndim == 1:
            inputs = inputs[np.newaxis, :]
        if isinstance(nsamples, (torch.Tensor, np.ndarray)):
            nsamples = int(nsamples.reshape(-1)[0])

        num_spks = args.num_spks
        window = int(args.sample_rate * args.decode_window)  # decoding window length
        stride = int(window * 0.75)  # decoding stride if segmentation is used
        t = inputs.shape[1]
        # Very long inputs are decoded window by window instead of in one pass.
        decode_do_segment = t > window * args.one_time_decode_length

        padding = self._padding_length(t, window, stride)
        if padding > 0:
            inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
        # Match the dtype the model was cast to in __init__.
        inputs = torch.from_numpy(inputs.astype(np.float32)).to(
            device=self.device, dtype=getattr(torch, self.dtype))

        with torch.no_grad():
            if decode_do_segment:
                out = self._decode_segments(model, inputs, num_spks, window, stride)
            else:
                out_list = model(inputs)
                out = [out_list[spk][0, :].cpu().numpy() for spk in range(num_spks)]

        # Trim padding first so the peak is measured on returned samples only;
        # skip scaling for silent output instead of dividing by zero.
        out = [wav[:nsamples] for wav in out]
        peak = max((float(np.max(np.abs(wav))) for wav in out if wav.size), default=0.0)
        if peak > 0:
            out = [wav / peak for wav in out]
        return out

    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
    ) -> List[np.ndarray]:
        """Inference

        Args:
            speech: input mixture, [B, T] or [T].
            speech_lengths: number of valid samples to return, or None.
        Returns:
            list of separated waveforms (numpy arrays), one per speaker.
        """
        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
        return out
funasr/bin/ss_inference_launch.py
New file
@@ -0,0 +1,253 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Optional
from typing import Union
import numpy as np
import torch
import soundfile as sf
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.bin.ss_infer import SpeechSeparator
def separated_wav_name(key: str, spk: int) -> str:
    """Return the wav file name for speaker ``spk`` (0-based) of utterance ``key``.

    A trailing ".wav" on ``key`` is replaced by "_s<spk+1>.wav"; keys without
    that extension get the suffix appended, so every speaker gets its own
    wav-typed file instead of all speakers overwriting one extension-less file.

    >>> separated_wav_name("mix.wav", 0)
    'mix_s1.wav'
    >>> separated_wav_name("utt1", 1)
    'utt1_s2.wav'
    """
    stem = key[:-len(".wav")] if key.endswith(".wav") else key
    return stem + "_s" + str(spk + 1) + ".wav"


def inference_ss(
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        ss_infer_config: Optional[str],
        ss_model_file: Optional[str],
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        num_spks: int = 2,
        sample_rate: int = 8000,
        param_dict: dict = None,
        **kwargs,
):
    """Build a SpeechSeparator and return a ``_forward`` closure that runs it.

    Args:
        batch_size: must be 1; batch decoding is not implemented.
        ngpu: use CUDA when >= 1 and available, otherwise CPU.
        log_level: logging level passed to ``logging.basicConfig``.
        ss_infer_config / ss_model_file: model config and checkpoint.
        output_dir: default directory for separated wavs.
        dtype: torch dtype name for the model.
        seed: random seed.
        num_workers: data-loader workers.
        num_spks: number of separated outputs written per utterance.
        sample_rate: sample rate used when writing wav files.
        param_dict: unused.
        **kwargs: ``ncpu`` sets torch's thread count; others are ignored.
    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None, output_dir_v2=None,
        fs=None, param_dict=None)``. It separates every utterance from the data
        iterator, writes one wav per speaker when an output directory is known,
        and returns the separation result of the last utterance (None if the
        iterator was empty).
    Raises:
        NotImplementedError: if ``batch_size > 1``.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech separator
    speech_separator_kwargs = dict(
        ss_infer_config=ss_infer_config,
        ss_model_file=ss_model_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
    speech_separator = SpeechSeparator(**speech_separator_kwargs)

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                # detach/cpu so GPU or grad-tracking tensors convert cleanly
                raw_inputs = raw_inputs.detach().cpu().numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="ss",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        # 4. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            # os.makedirs instead of a shell `mkdir -p`: no shell, no injection.
            os.makedirs(output_path, exist_ok=True)
        ss_results = None  # returned as-is when the loader yields nothing
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # do speech separation
            logging.info('decoding: {}'.format(keys[0]))
            ss_results = speech_separator(**batch)
            if output_path is None:
                continue
            for spk, wav in enumerate(ss_results[:num_spks]):
                wav_path = os.path.join(output_path, separated_wav_name(keys[0], spk))
                sf.write(wav_path, wav, sample_rate)
        if device == "cuda":
            torch.cuda.empty_cache()
        return ss_results

    return _forward
def inference_launch(mode, **kwargs):
    """Return the inference closure for ``mode``.

    Only "mossformer" is supported; it forwards ``kwargs`` to ``inference_ss``.
    Any other mode is logged at ERROR level (so it is visible at default log
    levels) and ``None`` is returned.
    """
    if mode == "mossformer":
        return inference_ss(**kwargs)
    logging.error("Unknown decoding mode: {}".format(mode))
    return None
def get_parser():
    """Build the command-line parser for speech-separation decoding.

    Returns:
        A ``config_argparse.ArgumentParser`` with logging, device, input-data,
        model and inference options.
    """
    parser = config_argparse.ArgumentParser(
        description="Speech Separator Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=1,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="2",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--ss_infer_config",
        type=str,
        help="SS infer configuration",
    )
    group.add_argument(
        "--ss_model_file",
        type=str,
        help="SS model parameter file",
    )
    group.add_argument(
        "--ss_train_config",
        type=str,
        help="SS training configuration",
    )
    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    # The dashed spellings are kept for backward compatibility; the underscore
    # aliases follow the '_' convention noted at the top of this parser.
    group.add_argument(
        "--num-spks", "--num_spks", dest="num_spks", type=int, default=2,
        help="Number of separated outputs written per utterance",
    )
    group.add_argument(
        "--one-time-decode-length", "--one_time_decode_length",
        dest="one_time_decode_length", type=int, default=60,
        help="the max length (second) for one-time decoding",
    )
    group.add_argument(
        "--decode-window", "--decode_window", dest="decode_window", type=int,
        default=1, help="segmental decoding window length (second)",
    )
    group.add_argument(
        "--sample-rate", "--sample_rate", dest="sample_rate", type=int, default=8000,
        help="Sample rate used when writing separated wav files",
    )
    return parser
def select_gpu_id(output_dir: str, gpuid_list: str, njob: int) -> str:
    """Return the GPU id (an entry of comma-separated ``gpuid_list``) for this job.

    The 1-based job id is the text after the last '.' of ``output_dir``
    (e.g. "exp/output.3" -> job 3), and ``njob`` consecutive jobs share one GPU.
    When that text is not a number, the job is treated as job 1 rather than
    crashing with ValueError.

    >>> select_gpu_id("exp/output.3", "0,1,2", 1)
    '2'
    """
    suffix = output_dir.rsplit(".", 1)[-1]
    jobid = int(suffix) if suffix.isdecimal() else 1
    return gpuid_list.split(",")[(jobid - 1) // njob]


def main(cmd=None):
    """Command-line entry point: parse ``cmd`` (or ``sys.argv``) and run decoding."""
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="mossformer",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))
    # gpu setting
    if args.ngpu > 0:
        gpuid = select_gpu_id(args.output_dir, args.gpuid_list, args.njob)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    inference_pipeline = inference_launch(**kwargs)
    if inference_pipeline is None:
        # Fail with a usage error instead of "'NoneType' object is not callable".
        parser.error("Unknown decoding mode: {}".format(args.mode))
    return inference_pipeline(kwargs["data_path_and_name_and_type"])


if __name__ == "__main__":
    main()
funasr/build_utils/build_model.py
@@ -5,6 +5,7 @@
from funasr.build_utils.build_punc_model import build_punc_model
from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
from funasr.build_utils.build_ss_model import build_ss_model
def build_model(args):
@@ -22,6 +23,8 @@
        model = build_diar_model(args)
    elif args.task_name == "sv":
        model = build_sv_model(args)
    elif args.task_name == "ss":
        model = build_ss_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_model_from_file.py
@@ -11,6 +11,18 @@
from funasr.models.base_model import FunASRModel
def load_checkpoint(checkpoint_path, use_cuda=1):
    """Load a checkpoint with ``torch.load``.

    Tensors keep their saved device only when ``use_cuda`` is truthy AND CUDA is
    actually available. Otherwise every storage is mapped to CPU, so a
    GPU-saved checkpoint can still be loaded on a CPU-only host (callers such
    as the "ss" task pass ``use_cuda=True`` unconditionally).
    """
    if use_cuda and torch.cuda.is_available():
        return torch.load(checkpoint_path)
    return torch.load(
        checkpoint_path, map_location=lambda storage, loc: storage)
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
    """Load the ``'model'`` entry of the checkpoint at ``checkpoint_path`` into ``model``.

    Loading stays non-strict, so mismatching keys are not fatal. They are
    logged as warnings instead of being dropped silently, which makes a wrong
    or partially matching checkpoint visible.
    """
    checkpoint = load_checkpoint(checkpoint_path, use_cuda)
    result = model.load_state_dict(checkpoint['model'], strict=False)
    if result.missing_keys:
        logging.warning(
            "Missing keys when loading %s: %s", checkpoint_path, result.missing_keys)
    if result.unexpected_keys:
        logging.warning(
            "Unexpected keys when loading %s: %s", checkpoint_path, result.unexpected_keys)
def build_model_from_file(
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
@@ -70,6 +82,9 @@
            model.load_state_dict(model_dict)
        else:
            model_dict = torch.load(model_file, map_location=device)
    if task_name == 'ss':
        reload_ss_for_eval(model, model_file, use_cuda=True)
        logging.info("model is loaded from path: {}".format(model_file))
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    if task_name == "vad":
funasr/build_utils/build_ss_model.py
New file
@@ -0,0 +1,15 @@
from funasr.models.e2e_ss import MossFormer
def build_ss_model(args):
    """Create a MossFormer separation model from the configuration namespace ``args``.

    Reads ``encoder_embedding_dim``, ``mossformer_sequence_dim``,
    ``num_mossformer_layer``, ``encoder_kernel_size``, ``norm``, ``num_spks``,
    ``skip_around_intra``, ``use_global_pos_enc`` and ``max_length``.
    """
    model_kwargs = dict(
        in_channels=args.encoder_embedding_dim,
        out_channels=args.mossformer_sequence_dim,
        num_blocks=args.num_mossformer_layer,
        kernel_size=args.encoder_kernel_size,
        norm=args.norm,
        num_spks=args.num_spks,
        skip_around_intra=args.skip_around_intra,
        use_global_pos_enc=args.use_global_pos_enc,
        max_length=args.max_length,
    )
    return MossFormer(**model_kwargs)
funasr/models/decoder/mossformer_decoder.py
New file
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Constructor arguments are forwarded unchanged to ``nn.ConvTranspose1d``.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(MossFormerDecoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L] (a 2-D input gets a
            channel dim inserted at position 1).
                where, B = Batchsize,
                       N = number of filters
                       L = time points

        Size-1 dims are squeezed from the result. When the fully squeezed
        result would be 1-D, only dim 1 is removed so a batch of one keeps its
        batch dim.

        Raises
        ------
        RuntimeError
            If ``x`` is not 2-D or 3-D.
        """
        if x.dim() not in [2, 3]:
            # type(self).__name__: module instances have no __name__ attribute.
            raise RuntimeError(
                "{} accept 2/3D tensor as input".format(type(self).__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x
funasr/models/e2e_ss.py
New file
@@ -0,0 +1,95 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
class MossFormer(FunASRModel):
    """The MossFormer model for separating input mixed speech into different speakers' speech.

    Pipeline: convolutional encoder -> mask network (one mask per speaker) ->
    masked encoder features -> shared transposed-convolution decoder.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder (also the channel
        count of the masks and of the decoder input).
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.
    kernel_size: int
        Encoder and decoder kernel size
    """

    def __init__(
        self,
        in_channels=512,
        out_channels=512,
        num_blocks=24,
        kernel_size=16,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer, self).__init__()
        self.num_spks = num_spks
        # Encoding: 1 input channel -> in_channels feature maps
        self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
        # Compute one mask per speaker; masks have in_channels channels
        self.mask_net = MossFormer_MaskNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            norm=norm,
            num_spks=num_spks,
            skip_around_intra=skip_around_intra,
            use_global_pos_enc=use_global_pos_enc,
            max_length=max_length,
        )
        # The decoder consumes masked encoder features, which have in_channels
        # channels. It was previously built with out_channels, which only worked
        # when in_channels == out_channels.
        self.dec = MossFormerDecoder(
            in_channels=in_channels,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            bias=False
        )

    def forward(self, input):
        """Separate ``input`` ([B, T] waveform) into ``num_spks`` waveforms.

        Returns a list of ``num_spks`` tensors, each [B, T]: the decoder output
        is zero-padded or truncated to exactly T samples.
        """
        mix_len = input.size(1)
        feats = self.enc(input)
        masks = self.mask_net(feats)  # [spks, B, in_channels, S]
        # Broadcast the shared encoder features over the speaker axis
        # (equivalent to stacking num_spks copies, without the copy).
        sep_feats = feats.unsqueeze(0) * masks
        # Decoding: [B, T_est, spks]
        est_source = torch.cat(
            [self.dec(sep_feats[i]).unsqueeze(-1) for i in range(self.num_spks)],
            dim=-1,
        )
        est_len = est_source.size(1)
        if mix_len > est_len:
            est_source = F.pad(est_source, (0, 0, 0, mix_len - est_len))
        else:
            est_source = est_source[:, :mix_len, :]
        return [est_source[:, :, spk] for spk in range(self.num_spks)]
funasr/models/encoder/mossformer_encoder.py
New file
@@ -0,0 +1,417 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
def select_norm(norm, dim, shape):
    """Return a normalization layer for ``dim`` channels chosen by name.

    "gln" -> GlobalLayerNorm(dim, shape), "cln" -> CumulativeLayerNorm(dim),
    "ln" -> single-group GroupNorm (eps=1e-8); any other value -> BatchNorm1d.
    """
    factories = {
        "gln": lambda: GlobalLayerNorm(dim, shape, elementwise_affine=True),
        "cln": lambda: CumulativeLayerNorm(dim, elementwise_affine=True),
        "ln": lambda: nn.GroupNorm(1, dim, eps=1e-8),
    }
    make_norm = factories.get(norm, lambda: nn.BatchNorm1d(dim))
    return make_norm()
class MossformerBlock(nn.Module):
    """A stack of ``depth`` FLASH_ShareA_FFConvM layers sharing one rotary embedding.

    Arguments
    ---------
    dim : int
        Feature dimension passed to every layer.
    depth : int
        Number of layers.
    group_size, query_key_dim, expansion_factor, causal, shift_tokens :
        Passed through to each FLASH_ShareA_FFConvM layer.
    attn_dropout : float
        Passed to each layer as ``dropout``.
    norm_type : str
        'scalenorm' (ScaleNorm) or 'layernorm' (nn.LayerNorm).

    Raises
    ------
    ValueError
        If ``norm_type`` is not one of the supported names (previously an
        ``assert``, which is stripped under ``python -O``).
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        attn_dropout = 0.1,
        norm_type = 'scalenorm',
        shift_tokens = True
    ):
        super().__init__()
        norm_classes = {'scalenorm': ScaleNorm, 'layernorm': nn.LayerNorm}
        if norm_type not in norm_classes:
            raise ValueError('norm_type must be one of scalenorm or layernorm')
        norm_klass = norm_classes[norm_type]
        self.group_size = group_size
        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J.
        # A single instance is shared by all layers.
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim = dim,
                group_size = group_size,
                query_key_dim = query_key_dim,
                expansion_factor = expansion_factor,
                causal = causal,
                dropout = attn_dropout,
                rotary_pos_emb = rotary_pos_emb,
                norm_klass = norm_klass,
                shift_tokens = shift_tokens,
            )
            for _ in range(depth)
        ])

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Apply every layer in order, passing ``mask`` to each."""
        for flash in self.layers:
            x = flash(x, mask = mask)
        return x
class MossFormer_MaskNet(nn.Module):
    """Estimate one non-negative mask per speaker from encoder features.

    Arguments
    ---------
    in_channels : int
        Channels of the encoder output (and of each returned mask).
    out_channels : int
        Channels used inside the MossFormer computation block.
    num_blocks : int
        Number of MossFormer layers in the computation block.
    norm : str
        Normalization type (see ``select_norm``).
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Residual connection around the computation block.
    use_global_pos_enc : bool
        Add a scaled sinusoidal position embedding before the block.
    max_length : int
        Accepted for interface compatibility; not used here.

    Shapes: input [B, in_channels, S] -> output [num_spks, B, in_channels, S].
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super().__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc
        if use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)
        self.mdl = Computation_Block(num_blocks, out_channels, norm, skip_around_intra=skip_around_intra)
        self.conv1d_out = nn.Conv1d(out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer: tanh branch multiplied by sigmoid gate
        self.output = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())

    def forward(self, x):
        """Map encoder features [B, N_in, S] to masks [spks, B, N_in, S]."""
        feats = self.conv1d_encoder(self.norm(x))  # [B, N, S]
        if self.use_global_pos_enc:
            # pos_enc takes the length from dim 1, so it is fed [B, S, N];
            # its [S, N] table is transposed to [N, S] and broadcast over B.
            table = self.pos_enc(feats.transpose(1, -1))
            feats = feats + table.transpose(0, -1)
        feats = self.prelu(self.mdl(feats))
        feats = self.conv1d_out(feats)  # [B, N*spks, S]
        batch, _, frames = feats.shape
        feats = feats.view(batch * self.num_spks, -1, frames)  # [B*spks, N, S]
        feats = self.output(feats) * self.output_gate(feats)
        feats = self.conv1_decoder(feats)  # [B*spks, N_in, S]
        _, chans, frames = feats.shape
        masks = self.activation(feats.view(batch, self.num_spks, chans, frames))
        return masks.transpose(0, 1)  # [spks, B, N_in, S]
class MossFormerEncoder(nn.Module):
    """Convolutional Encoder Layer.
    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of  input channels.
    out_channels : int
        Number of output channels.
    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = Encoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """
    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super(MossFormerEncoder, self).__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels
    def forward(self, x):
        """Return the encoded output.
        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L].
        Return
        ------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out].
        where B = Batchsize
              L = Number of timepoints
              N = Number of filters
              T_out = Number of timepoints at the output of the encoder
        """
        # B x L -> B x 1 x L
        if self.in_channels == 1:
            x = torch.unsqueeze(x, dim=1)
        # B x 1 x L -> B x N x T_out
        x = self.conv1d(x)
        x = F.relu(x)
        return x
class MossFormerM(nn.Module):
    """A MossformerBlock stack followed by a final LayerNorm over the feature dim.

    Arguments
    ---------
    num_blocks : int
        Number of MossFormer layers.
    d_model : int
        Feature dimension of the input. It must be given; ``nn.LayerNorm``
        cannot be built with the ``None`` default.
    causal : bool
        true for causal / false for non causal.
    group_size : int
        The chunk size.
    query_key_dim : int
        The attention vector dimension.
    expansion_factor : float
        Expansion factor for the linear projection in the conv module.
    attn_dropout : float
        Dropout passed to the layers.

    ``forward(src)`` takes ``src`` as [B, L, N] (batch, time, features) and
    returns a tensor that is layer-normalized over its last dimension.
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        attn_dropout = 0.1
    ):
        super().__init__()
        block_config = dict(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.mossformerM = MossformerBlock(**block_config)
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
    ):
        """Run the layer stack on ``src`` ([B, L, N]) and layer-normalize the result."""
        return self.norm(self.mossformerM(src))
class Computation_Block(nn.Module):
    """MossFormerM applied along time, with optional normalization and residual.

    Arguments
    ---------
    num_blocks : int
        Number of MossFormer layers inside ``MossFormerM``.
    out_channels : int
        Feature dimension N of the input and output.
    norm : str or None
        Normalization type for ``select_norm``; ``None`` disables it.
    skip_around_intra : bool
        Add the block input to its output.

    Input and output are both [B, N, S] (batch, features, time).
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super().__init__()
        # MossFormerM: the original MossFormer (a MossFormer2M variant with
        # recurrence is not used here).
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Map [B, N, S] -> [B, N, S]."""
        _batch, _chans, _frames = x.shape  # unpacking also rejects non-3-D input
        # The MossFormer stack works on [B, S, N].
        seq_major = self.intra_mdl(x.permute(0, 2, 1).contiguous())
        result = seq_major.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            result = self.intra_norm(result)
        if self.skip_around_intra:
            result = result + x
        return result
funasr/modules/embedding.py
@@ -9,6 +9,7 @@
import math
import torch
import torch.nn.functional as F
from torch import einsum
def _pre_hook(
    state_dict,
@@ -510,3 +511,19 @@
        pos_enc = self.dropout(pos_enc)
        return pos_enc
class ScaledSinuEmbedding(torch.nn.Module):
    """Sinusoidal position table multiplied by one learnable scalar.

    ``forward(x)`` reads the sequence length from ``x.shape[1]`` and returns a
    [length, 2 * ceil(dim / 2)] tensor, with sines in the first half and
    cosines in the second. For even ``dim`` that is [length, dim].
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1,))
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        length = x.shape[1]
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq)
        # outer product: angles[i, j] = positions[i] * inv_freq[j]
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table * self.scale
funasr/modules/layer_norm.py
@@ -7,6 +7,7 @@
"""Layer normalization module."""
import torch
import torch.nn as nn
class LayerNorm(torch.nn.LayerNorm):
@@ -40,3 +41,137 @@
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Each sample is normalized with a single mean and variance computed over
    all non-batch axes, then optionally scaled and shifted per channel.

    Arguments
    ---------
       dim : int
           Number of channels (size of axis 1 of the input).
       shape : int
           Rank of the expected input: 3 for [N, C, L] or 4 for [N, C, K, S].
           It fixes the shape of the affine parameters ((C, 1) or (C, 1, 1))
           so that they broadcast over the non-channel axes.
       eps : float
           A value added to the denominator for numerical stability.
       elementwise_affine : bool
          A boolean value that when set to True,
          this module has learnable per-channel affine parameters
          initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
       ValueError
           If ``elementwise_affine`` is True and ``shape`` is not 3 or 4.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            if shape == 3:
                param_shape = (self.dim, 1)
            elif shape == 4:
                param_shape = (self.dim, 1, 1)
            else:
                # Previously ``weight``/``bias`` were silently left undefined
                # here, which only surfaced as an AttributeError in forward().
                raise ValueError(
                    f"GlobalLayerNorm supports shape 3 or 4, got {shape!r}"
                )
            self.weight = nn.Parameter(torch.ones(param_shape))
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]. Tensors of any other
            rank are returned unchanged.
        """
        if x.dim() not in (3, 4):
            return x
        # gln: one mean/var per sample, i.e. N x 1 x 1 (x 1), reduced over
        # every axis except the batch axis.
        reduce_dims = tuple(range(1, x.dim()))
        mean = torch.mean(x, reduce_dims, keepdim=True)
        var = torch.mean((x - mean) ** 2, reduce_dims, keepdim=True)
        if self.elementwise_affine:
            return (
                self.weight * (x - mean) / torch.sqrt(var + self.eps)
                + self.bias
            )
        return (x - mean) / torch.sqrt(var + self.eps)
class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization for channel-first tensors.

    At every position, the channel axis (axis 1) is normalized with
    ``nn.LayerNorm(dim, eps=1e-8)``. Inputs of rank 3 ([N, C, L]) and rank 4
    ([N, C, K, S]) are supported; any other rank is returned unchanged.

    Arguments
    ---------
       dim : int
        Number of channels to normalize.
       elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        rank = x.dim()
        if rank not in (3, 4):
            return x
        # Move channels (axis 1) to the last axis, which nn.LayerNorm normalizes.
        moved = x.movedim(1, -1)
        if rank == 4:
            moved = moved.contiguous()
        normed = super().forward(moved)
        # Restore the channel-first layout.
        restored = normed.movedim(-1, 1)
        if rank == 4:
            restored = restored.contiguous()
        return restored
class ScaleNorm(nn.Module):
    """Scale normalization along the last axis with one learned gain.

    Each vector is divided by its root-mean-square (L2 norm times
    ``dim ** -0.5``, clamped below by ``eps``) and multiplied by ``g``.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        rms = x.norm(dim = -1, keepdim = True) * self.scale
        denom = rms.clamp(min = self.eps)
        return x / denom * self.g
funasr/modules/mossformer.py
New file
@@ -0,0 +1,307 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
def identity(t, *args, **kwargs):
    """Hand back ``t`` itself; any extra positional/keyword arguments are ignored."""
    del args, kwargs  # accepted only so this can stand in for other callables
    return t
def append_dims(x, num_dims):
    """Return ``x`` viewed with ``num_dims`` trailing size-1 axes; ``x`` itself if ``num_dims <= 0``."""
    if num_dims > 0:
        trailing = (1,) * num_dims
        return x.view(*x.shape, *trailing)
    return x
def exists(val):
    """Report whether ``val`` was supplied, i.e. is anything other than ``None``."""
    is_missing = val is None
    return not is_missing
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if not exists(val):
        return d
    return val
def padding_to_multiple_of(n, mult):
    """Return how many elements must be appended to length ``n`` to reach a multiple of ``mult``."""
    leftover = n % mult
    return 0 if leftover == 0 else mult - leftover
class Transpose(nn.Module):
    """Swap two axes of the input so ``torch.transpose`` can sit inside ``nn.Sequential``.

    Args:
        shape: the pair of axis indices passed to ``torch.transpose``.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return torch.transpose(x, *self.shape)
class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution;
            must be a multiple of ``in_channels``
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        # groups=in_channels gives each input channel its own set of filters.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )
    def forward(self, inputs):
        return self.conv(inputs)
class ConvModule(nn.Module):
    """
    Residual depthwise 1-D convolution over the time axis.

    The input is moved to channel-first layout, passed through a single
    bias-free depthwise convolution with 'SAME' padding, moved back, and
    added to the input. There is no pointwise convolution, GLU,
    normalization or dropout in this module.

    Args:
        in_channels (int): Number of channels in the input (its last axis)
        kernel_size (int, optional): Odd size of the depthwise kernel. Default: 17
        expansion_factor (int, optional): Must be 2; otherwise unused.
        dropout_p (float, optional): Accepted for interface compatibility; unused.
    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): ``inputs`` plus their depthwise-convolved version.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 17,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
        # (batch, time, dim) -> (batch, dim, time) -> depthwise conv; padding of
        # (kernel_size - 1) // 2 with stride 1 keeps the time length unchanged.
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        )
    def forward(self, inputs):
        # Transpose the conv output back to (batch, time, dim) and add the residual.
        return inputs + self.sequential(inputs).transpose(1, 2)
class OffsetScale(nn.Module):
    """Per-head learned scale and offset applied to one shared feature tensor.

    For ``x`` of shape (..., dim), ``forward`` returns a tuple of ``heads``
    tensors of shape (..., dim); entry ``h`` is ``x * gamma[h] + beta[h]``.
    """

    def __init__(self, dim, heads = 1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        # gamma is re-drawn from N(0, 0.02^2) instead of keeping the ones above.
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        # (..., d) with (h, d) -> (..., h, d): every head sees the same x.
        per_head = einsum('... d, h d -> ... h d', x, self.gamma)
        shifted = per_head + self.beta
        # Split the head axis into a tuple of ``heads`` tensors.
        return shifted.unbind(dim = -2)
class FFConvM(nn.Module):
    """Feed-forward projection: norm -> Linear -> SiLU -> ConvModule -> Dropout.

    Args:
        dim_in: number of input features (last axis).
        dim_out: number of output features.
        norm_klass: normalization class, instantiated as ``norm_klass(dim_in)``.
        dropout: dropout probability applied to the output.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        # Keep this exact order: the Sequential's state_dict keys
        # (mdl.0 ... mdl.4) depend on it.
        stages = [
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        ]
        self.mdl = nn.Sequential(*stages)

    def forward(
        self,
        x,
    ):
        return self.mdl(x)
class FLASH_ShareA_FFConvM(nn.Module):
    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 1.,
        causal = False,
        dropout = 0.1,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens
        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # norm
        self.dropout = nn.Dropout(dropout)
        # projections
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )
        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
        self.to_out = FFConvM(
            dim_in = dim*2,
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )
        self.gateActivate=nn.Sigmoid()
    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """
        normed_x = x
        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
        residual = x
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)
        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)
        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
        out = (att_u*v ) * self.gateActivate(att_v*u)
        x = x + self.to_out(out)
        return x
    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
        """Compute the shared quadratic + linear attention outputs for ``v`` and ``u``.

        The sequence is zero-padded to a multiple of ``self.group_size`` and
        split into groups. Within each group a quadratic attention map
        ``relu(quad_q . quad_k / group_size) ** 2`` is used; a linear attention
        built from ``lin_q``/``lin_k`` is added on top (over the whole sequence
        when non-causal, over preceding groups only when causal). The same
        attention weights are applied to both ``v`` and ``u``.

        Args:
            x: only used for its batch size (``x.shape[0]``), sequence length
                (``x.shape[-2]``) and device.
            quad_q, lin_q, quad_k, lin_k: (b, n, d) query/key tensors.
            v, u: (b, n, e) value tensors.
            mask: optional boolean tensor of shape (b, n). False positions are
                zeroed in ``lin_k`` and in the columns of the quadratic
                attention map.

        Returns:
            A lazy ``map`` yielding two (b, n, e) tensors - attention applied
            to ``v``, then to ``u`` - with the group padding trimmed off.
            Callers unpack it directly.
        """
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
        # Masked positions must not contribute as keys to the linear attention.
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)
        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
        # padding for groups: pad the sequence axis (-2) with zeros and mark
        # the padded positions as False in the mask (creating one if needed).
        padding = padding_to_multiple_of(n, g)
        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)
        # group along sequence: (b, g*n, d) -> (b, g, n, d); here ``n`` in the
        # pattern is the einops group length, the Python ``n`` is unchanged.
        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
        if exists(mask):
            # Broadcast over query rows: mask selects key columns j.
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
        # calculate quadratic attention output
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)
        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)
        if self.causal:
            # Zero every key j > i inside a group.
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)
        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
        # calculate linear attention output
        if self.causal:
            # Per-group key/value summaries; each group then sees only the
            # sum over strictly earlier groups.
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim = 1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            # One global key/value summary over all groups, normalized by the
            # unpadded sequence length n.
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
        # fold back groups into full sequence, and excise out padding
        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))