From 8b9c53cfc30b43129f66dad04d761b87f7dcc89f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 21 十一月 2023 15:47:18 +0800
Subject: [PATCH] funasr v2 setup

---
 funasr/layers/complex_utils.py                              |    2 
 funasr/models/frontend/frontends_utils/mask_estimator.py    |   77 +++++
 funasr/models/encoder/mossformer_encoder.py                 |    2 
 funasr/models/frontend/default.py                           |    4 
 funasr/models/frontend/frontends_utils/dnn_beamformer.py    |  172 +++++++++++
 funasr/models/frontend/frontends_utils/__init__.py          |    1 
 funasr/models/frontend/frontends_utils/feature_transform.py |  263 +++++++++++++++++
 funasr/models/frontend/frontends_utils/dnn_wpe.py           |   93 ++++++
 funasr/layers/stft.py                                       |    2 
 funasr/models/frontend/frontends_utils/beamformer.py        |   84 +++++
 funasr/models/frontend/frontends_utils/frontend.py          |  151 ++++++++++
 11 files changed, 846 insertions(+), 5 deletions(-)

diff --git a/funasr/layers/complex_utils.py b/funasr/layers/complex_utils.py
index d6f7c6d..5d313c6 100644
--- a/funasr/layers/complex_utils.py
+++ b/funasr/layers/complex_utils.py
@@ -9,7 +9,7 @@
     from torch_complex import functional as FC
     from torch_complex.tensor import ComplexTensor
 except:
-    raise "Please install torch_complex firstly"
+    print("Please install torch_complex firstly")
 
 
 
diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index 67ebf7a..c71af8e 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -8,7 +8,7 @@
 try:
     from torch_complex.tensor import ComplexTensor
 except:
-    raise "Please install torch_complex firstly"
+    print("Please install torch_complex firstly")
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.layers.complex_utils import is_complex
 from funasr.layers.inversible_interface import InversibleInterface
diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py
index f7d9c47..6c092e7 100644
--- a/funasr/models/encoder/mossformer_encoder.py
+++ b/funasr/models/encoder/mossformer_encoder.py
@@ -4,7 +4,7 @@
 try:
     from rotary_embedding_torch import RotaryEmbedding
 except:
-    raise "Please install rotary_embedding_torch by: \n pip install -U funasr[all]"
+    print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
 from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
 from funasr.modules.embedding import ScaledSinuEmbedding
 from funasr.modules.mossformer import FLASH_ShareA_FFConvM
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 8d60e20..3435321 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -9,12 +9,12 @@
 try:
     from torch_complex.tensor import ComplexTensor
 except:
-    raise "Please install torch_complex firstly"
+    print("Please install torch_complex firstly")
 
 from funasr.layers.log_mel import LogMel
 from funasr.layers.stft import Stft
 from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.modules.frontends.frontend import Frontend
+from funasr.models.frontend.frontends_utils.frontend import Frontend
 from funasr.utils.get_default_kwargs import get_default_kwargs
 from funasr.modules.nets_utils import make_pad_mask
 
diff --git a/funasr/models/frontend/frontends_utils/__init__.py b/funasr/models/frontend/frontends_utils/__init__.py
new file mode 100644
index 0000000..b7f1773
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/__init__.py
@@ -0,0 +1 @@
+"""Initialize sub package."""
diff --git a/funasr/models/frontend/frontends_utils/beamformer.py b/funasr/models/frontend/frontends_utils/beamformer.py
new file mode 100644
index 0000000..f3eccee
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/beamformer.py
@@ -0,0 +1,84 @@
+import torch
+from torch_complex import functional as FC
+from torch_complex.tensor import ComplexTensor
+
+
+def get_power_spectral_density_matrix(
+    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
+) -> ComplexTensor:
+    """Return cross-channel power spectral density (PSD) matrix
+
+    Args:
+        xs (ComplexTensor): (..., F, C, T)
+        mask (torch.Tensor): (..., F, C, T)
+        normalization (bool):
+        eps (float):
+    Returns
+        psd (ComplexTensor): (..., F, C, C)
+
+    """
+    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
+    psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
+
+    # Averaging mask along C: (..., C, T) -> (..., T)
+    mask = mask.mean(dim=-2)
+
+    # Normalized mask along T: (..., T)
+    if normalization:
+        # If assuming the tensor is padded with zero, the summation along
+        # the time axis is same regardless of the padding length.
+        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
+
+    # psd: (..., T, C, C)
+    psd = psd_Y * mask[..., None, None]
+    # (..., T, C, C) -> (..., C, C)
+    psd = psd.sum(dim=-3)
+
+    return psd
+
+
+def get_mvdr_vector(
+    psd_s: ComplexTensor,
+    psd_n: ComplexTensor,
+    reference_vector: torch.Tensor,
+    eps: float = 1e-15,
+) -> ComplexTensor:
+    """Return the MVDR(Minimum Variance Distortionless Response) vector:
+
+        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
+
+    Reference:
+        On optimal frequency-domain multichannel linear filtering
+        for noise reduction; M. Souden et al., 2010;
+        https://ieeexplore.ieee.org/document/5089420
+
+    Args:
+        psd_s (ComplexTensor): (..., F, C, C)
+        psd_n (ComplexTensor): (..., F, C, C)
+        reference_vector (torch.Tensor): (..., C)
+        eps (float):
+    Returns:
+        beamform_vector (ComplexTensor)r: (..., F, C)
+    """
+    # Add eps
+    C = psd_n.size(-1)
+    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
+    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
+    eye = eye.view(*shape)
+    psd_n += eps * eye
+
+    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
+    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
+    # ws: (..., C, C) / (...,) -> (..., C, C)
+    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
+    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
+    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
+    return beamform_vector
+
+
+def apply_beamforming_vector(
+    beamform_vector: ComplexTensor, mix: ComplexTensor
+) -> ComplexTensor:
+    # (..., C) x (..., C, T) -> (..., T)
+    es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
+    return es
diff --git a/funasr/models/frontend/frontends_utils/dnn_beamformer.py b/funasr/models/frontend/frontends_utils/dnn_beamformer.py
new file mode 100644
index 0000000..05e241d
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/dnn_beamformer.py
@@ -0,0 +1,172 @@
+"""DNN beamformer module."""
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
+from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
+from funasr.models.frontend.frontends_utils.beamformer import (
+    get_power_spectral_density_matrix,  # noqa: H301
+)
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
+from torch_complex.tensor import ComplexTensor
+
+
+class DNN_Beamformer(torch.nn.Module):
+    """DNN mask based Beamformer
+
+    Citation:
+        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
+        https://arxiv.org/abs/1703.04783
+
+    """
+
+    def __init__(
+        self,
+        bidim,
+        btype="blstmp",
+        blayers=3,
+        bunits=300,
+        bprojs=320,
+        bnmask=2,
+        dropout_rate=0.0,
+        badim=320,
+        ref_channel: int = -1,
+        beamformer_type="mvdr",
+    ):
+        super().__init__()
+        self.mask = MaskEstimator(
+            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
+        )
+        self.ref = AttentionReference(bidim, badim)
+        self.ref_channel = ref_channel
+
+        self.nmask = bnmask
+
+        if beamformer_type != "mvdr":
+            raise ValueError(
+                "Not supporting beamformer_type={}".format(beamformer_type)
+            )
+        self.beamformer_type = beamformer_type
+
+    def forward(
+        self, data: ComplexTensor, ilens: torch.LongTensor
+    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
+        """The forward function
+
+        Notation:
+            B: Batch
+            C: Channel
+            T: Time or Sequence length
+            F: Freq
+
+        Args:
+            data (ComplexTensor): (B, T, C, F)
+            ilens (torch.Tensor): (B,)
+        Returns:
+            enhanced (ComplexTensor): (B, T, F)
+            ilens (torch.Tensor): (B,)
+
+        """
+
+        def apply_beamforming(data, ilens, psd_speech, psd_noise):
+            # u: (B, C)
+            if self.ref_channel < 0:
+                u, _ = self.ref(psd_speech, ilens)
+            else:
+                # (optional) Create onehot vector for fixed reference microphone
+                u = torch.zeros(
+                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
+                )
+                u[..., self.ref_channel].fill_(1)
+
+            ws = get_mvdr_vector(psd_speech, psd_noise, u)
+            enhanced = apply_beamforming_vector(ws, data)
+
+            return enhanced, ws
+
+        # data (B, T, C, F) -> (B, F, C, T)
+        data = data.permute(0, 3, 2, 1)
+
+        # mask: (B, F, C, T)
+        masks, _ = self.mask(data, ilens)
+        assert self.nmask == len(masks)
+
+        if self.nmask == 2:  # (mask_speech, mask_noise)
+            mask_speech, mask_noise = masks
+
+            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
+            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
+
+            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
+
+            # (..., F, T) -> (..., T, F)
+            enhanced = enhanced.transpose(-1, -2)
+            mask_speech = mask_speech.transpose(-1, -3)
+        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
+            mask_speech = list(masks[:-1])
+            mask_noise = masks[-1]
+
+            psd_speeches = [
+                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
+            ]
+            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
+
+            enhanced = []
+            ws = []
+            for i in range(self.nmask - 1):
+                psd_speech = psd_speeches.pop(i)
+                # treat all other speakers' psd_speech as noises
+                enh, w = apply_beamforming(
+                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
+                )
+                psd_speeches.insert(i, psd_speech)
+
+                # (..., F, T) -> (..., T, F)
+                enh = enh.transpose(-1, -2)
+                mask_speech[i] = mask_speech[i].transpose(-1, -3)
+
+                enhanced.append(enh)
+                ws.append(w)
+
+        return enhanced, ilens, mask_speech
+
+
+class AttentionReference(torch.nn.Module):
+    def __init__(self, bidim, att_dim):
+        super().__init__()
+        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
+        self.gvec = torch.nn.Linear(att_dim, 1)
+
+    def forward(
+        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
+    ) -> Tuple[torch.Tensor, torch.LongTensor]:
+        """The forward function
+
+        Args:
+            psd_in (ComplexTensor): (B, F, C, C)
+            ilens (torch.Tensor): (B,)
+            scaling (float):
+        Returns:
+            u (torch.Tensor): (B, C)
+            ilens (torch.Tensor): (B,)
+        """
+        B, _, C = psd_in.size()[:3]
+        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
+        # psd_in: (B, F, C, C)
+        psd = psd_in.masked_fill(
+            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
+        )
+        # psd: (B, F, C, C) -> (B, C, F)
+        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
+
+        # Calculate amplitude
+        psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
+
+        # (B, C, F) -> (B, C, F2)
+        mlp_psd = self.mlp_psd(psd_feat)
+        # (B, C, F2) -> (B, C, 1) -> (B, C)
+        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
+        u = F.softmax(scaling * e, dim=-1)
+        return u, ilens
diff --git a/funasr/models/frontend/frontends_utils/dnn_wpe.py b/funasr/models/frontend/frontends_utils/dnn_wpe.py
new file mode 100644
index 0000000..ffe12a3
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/dnn_wpe.py
@@ -0,0 +1,93 @@
+from typing import Tuple
+
+from pytorch_wpe import wpe_one_iteration
+import torch
+from torch_complex.tensor import ComplexTensor
+
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
+from funasr.modules.nets_utils import make_pad_mask
+
+
+class DNN_WPE(torch.nn.Module):
+    def __init__(
+        self,
+        wtype: str = "blstmp",
+        widim: int = 257,
+        wlayers: int = 3,
+        wunits: int = 300,
+        wprojs: int = 320,
+        dropout_rate: float = 0.0,
+        taps: int = 5,
+        delay: int = 3,
+        use_dnn_mask: bool = True,
+        iterations: int = 1,
+        normalization: bool = False,
+    ):
+        super().__init__()
+        self.iterations = iterations
+        self.taps = taps
+        self.delay = delay
+
+        self.normalization = normalization
+        self.use_dnn_mask = use_dnn_mask
+
+        self.inverse_power = True
+
+        if self.use_dnn_mask:
+            self.mask_est = MaskEstimator(
+                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
+            )
+
+    def forward(
+        self, data: ComplexTensor, ilens: torch.LongTensor
+    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
+        """The forward function
+
+        Notation:
+            B: Batch
+            C: Channel
+            T: Time or Sequence length
+            F: Freq or Some dimension of the feature vector
+
+        Args:
+            data: (B, C, T, F)
+            ilens: (B,)
+        Returns:
+            data: (B, C, T, F)
+            ilens: (B,)
+        """
+        # (B, T, C, F) -> (B, F, C, T)
+        enhanced = data = data.permute(0, 3, 2, 1)
+        mask = None
+
+        for i in range(self.iterations):
+            # Calculate power: (..., C, T)
+            power = enhanced.real**2 + enhanced.imag**2
+            if i == 0 and self.use_dnn_mask:
+                # mask: (B, F, C, T)
+                (mask,), _ = self.mask_est(enhanced, ilens)
+                if self.normalization:
+                    # Normalize along T
+                    mask = mask / mask.sum(dim=-1)[..., None]
+                # (..., C, T) * (..., C, T) -> (..., C, T)
+                power = power * mask
+
+            # Averaging along the channel axis: (..., C, T) -> (..., T)
+            power = power.mean(dim=-2)
+
+            # enhanced: (..., C, T) -> (..., C, T)
+            enhanced = wpe_one_iteration(
+                data.contiguous(),
+                power,
+                taps=self.taps,
+                delay=self.delay,
+                inverse_power=self.inverse_power,
+            )
+
+            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
+
+        # (B, F, C, T) -> (B, T, C, F)
+        enhanced = enhanced.permute(0, 3, 2, 1)
+        if mask is not None:
+            mask = mask.transpose(-1, -3)
+        return enhanced, ilens, mask
diff --git a/funasr/models/frontend/frontends_utils/feature_transform.py b/funasr/models/frontend/frontends_utils/feature_transform.py
new file mode 100644
index 0000000..353dca1
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/feature_transform.py
@@ -0,0 +1,263 @@
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import librosa
+import numpy as np
+import torch
+from torch_complex.tensor import ComplexTensor
+
+from funasr.modules.nets_utils import make_pad_mask
+
+
+class FeatureTransform(torch.nn.Module):
+    def __init__(
+        self,
+        # Mel options,
+        fs: int = 16000,
+        n_fft: int = 512,
+        n_mels: int = 80,
+        fmin: float = 0.0,
+        fmax: float = None,
+        # Normalization
+        stats_file: str = None,
+        apply_uttmvn: bool = True,
+        uttmvn_norm_means: bool = True,
+        uttmvn_norm_vars: bool = False,
+    ):
+        super().__init__()
+        self.apply_uttmvn = apply_uttmvn
+
+        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+        self.stats_file = stats_file
+        if stats_file is not None:
+            self.global_mvn = GlobalMVN(stats_file)
+        else:
+            self.global_mvn = None
+
+        if self.apply_uttmvn is not None:
+            self.uttmvn = UtteranceMVN(
+                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
+            )
+        else:
+            self.uttmvn = None
+
+    def forward(
+        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
+    ) -> Tuple[torch.Tensor, torch.LongTensor]:
+        # (B, T, F) or (B, T, C, F)
+        if x.dim() not in (3, 4):
+            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
+        if not torch.is_tensor(ilens):
+            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
+
+        if x.dim() == 4:
+            # h: (B, T, C, F) -> h: (B, T, F)
+            if self.training:
+                # Select 1ch randomly
+                ch = np.random.randint(x.size(2))
+                h = x[:, :, ch, :]
+            else:
+                # Use the first channel
+                h = x[:, :, 0, :]
+        else:
+            h = x
+
+        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
+        h = h.real**2 + h.imag**2
+
+        h, _ = self.logmel(h, ilens)
+        if self.stats_file is not None:
+            h, _ = self.global_mvn(h, ilens)
+        if self.apply_uttmvn:
+            h, _ = self.uttmvn(h, ilens)
+
+        return h, ilens
+
+
+class LogMel(torch.nn.Module):
+    """Convert STFT to fbank feats
+
+    The arguments is same as librosa.filters.mel
+
+    Args:
+        fs: number > 0 [scalar] sampling rate of the incoming signal
+        n_fft: int > 0 [scalar] number of FFT components
+        n_mels: int > 0 [scalar] number of Mel bands to generate
+        fmin: float >= 0 [scalar] lowest frequency (in Hz)
+        fmax: float >= 0 [scalar] highest frequency (in Hz).
+            If `None`, use `fmax = fs / 2.0`
+        htk: use HTK formula instead of Slaney
+        norm: {None, 1, np.inf} [scalar]
+            if 1, divide the triangular mel weights by the width of the mel band
+            (area normalization).  Otherwise, leave all the triangles aiming for
+            a peak value of 1.0
+
+    """
+
+    def __init__(
+        self,
+        fs: int = 16000,
+        n_fft: int = 512,
+        n_mels: int = 80,
+        fmin: float = 0.0,
+        fmax: float = None,
+        htk: bool = False,
+        norm=1,
+    ):
+        super().__init__()
+
+        _mel_options = dict(
+            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
+        )
+        self.mel_options = _mel_options
+
+        # Note(kamo): The mel matrix of librosa is different from kaldi.
+        melmat = librosa.filters.mel(**_mel_options)
+        # melmat: (D2, D1) -> (D1, D2)
+        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
+
+    def extra_repr(self):
+        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
+
+    def forward(
+        self, feat: torch.Tensor, ilens: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.LongTensor]:
+        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
+        mel_feat = torch.matmul(feat, self.melmat)
+
+        logmel_feat = (mel_feat + 1e-20).log()
+        # Zero padding
+        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
+        return logmel_feat, ilens
+
+
+class GlobalMVN(torch.nn.Module):
+    """Apply global mean and variance normalization
+
+    Args:
+        stats_file(str): npy file of 1-dim array or text file.
+            From the _first element to
+            the {(len(array) - 1) / 2}th element are treated as
+            the sum of features,
+            and the rest excluding the last elements are
+            treated as the sum of the square value of features,
+            and the last elements eqauls to the number of samples.
+        std_floor(float):
+    """
+
+    def __init__(
+        self,
+        stats_file: str,
+        norm_means: bool = True,
+        norm_vars: bool = True,
+        eps: float = 1.0e-20,
+    ):
+        super().__init__()
+        self.norm_means = norm_means
+        self.norm_vars = norm_vars
+
+        self.stats_file = stats_file
+        stats = np.load(stats_file)
+
+        stats = stats.astype(float)
+        assert (len(stats) - 1) % 2 == 0, stats.shape
+
+        count = stats.flatten()[-1]
+        mean = stats[: (len(stats) - 1) // 2] / count
+        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
+        std = np.maximum(np.sqrt(var), eps)
+
+        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
+        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
+
+    def extra_repr(self):
+        return (
+            f"stats_file={self.stats_file}, "
+            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
+        )
+
+    def forward(
+        self, x: torch.Tensor, ilens: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.LongTensor]:
+        # feat: (B, T, D)
+        if self.norm_means:
+            x += self.bias.type_as(x)
+            x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
+
+        if self.norm_vars:
+            x *= self.scale.type_as(x)
+        return x, ilens
+
+
+class UtteranceMVN(torch.nn.Module):
+    def __init__(
+        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
+    ):
+        super().__init__()
+        self.norm_means = norm_means
+        self.norm_vars = norm_vars
+        self.eps = eps
+
+    def extra_repr(self):
+        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
+
+    def forward(
+        self, x: torch.Tensor, ilens: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.LongTensor]:
+        return utterance_mvn(
+            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
+        )
+
+
+def utterance_mvn(
+    x: torch.Tensor,
+    ilens: torch.LongTensor,
+    norm_means: bool = True,
+    norm_vars: bool = False,
+    eps: float = 1.0e-20,
+) -> Tuple[torch.Tensor, torch.LongTensor]:
+    """Apply utterance mean and variance normalization
+
+    Args:
+        x: (B, T, D), assumed zero padded
+        ilens: (B, T, D)
+        norm_means:
+        norm_vars:
+        eps:
+
+    """
+    ilens_ = ilens.type_as(x)
+    # mean: (B, D)
+    mean = x.sum(dim=1) / ilens_[:, None]
+
+    if norm_means:
+        x -= mean[:, None, :]
+        x_ = x
+    else:
+        x_ = x - mean[:, None, :]
+
+    # Zero padding
+    x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
+    if norm_vars:
+        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
+        var = torch.clamp(var, min=eps)
+        x /= var.sqrt()[:, None, :]
+        x_ = x
+    return x_, ilens
+
+
+def feature_transform_for(args, n_fft):
+    return FeatureTransform(
+        # Mel options,
+        fs=args.fbank_fs,
+        n_fft=n_fft,
+        n_mels=args.n_mels,
+        fmin=args.fbank_fmin,
+        fmax=args.fbank_fmax,
+        # Normalization
+        stats_file=args.stats_file,
+        apply_uttmvn=args.apply_uttmvn,
+        uttmvn_norm_means=args.uttmvn_norm_means,
+        uttmvn_norm_vars=args.uttmvn_norm_vars,
+    )
diff --git a/funasr/models/frontend/frontends_utils/frontend.py b/funasr/models/frontend/frontends_utils/frontend.py
new file mode 100644
index 0000000..bf266e9
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/frontend.py
@@ -0,0 +1,151 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy
+import torch
+import torch.nn as nn
+from torch_complex.tensor import ComplexTensor
+
+from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
+from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
+
+
+class Frontend(nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        # WPE options
+        use_wpe: bool = False,
+        wtype: str = "blstmp",
+        wlayers: int = 3,
+        wunits: int = 300,
+        wprojs: int = 320,
+        wdropout_rate: float = 0.0,
+        taps: int = 5,
+        delay: int = 3,
+        use_dnn_mask_for_wpe: bool = True,
+        # Beamformer options
+        use_beamformer: bool = False,
+        btype: str = "blstmp",
+        blayers: int = 3,
+        bunits: int = 300,
+        bprojs: int = 320,
+        bnmask: int = 2,
+        badim: int = 320,
+        ref_channel: int = -1,
+        bdropout_rate=0.0,
+    ):
+        super().__init__()
+
+        self.use_beamformer = use_beamformer
+        self.use_wpe = use_wpe
+        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
+        # use frontend for all the data,
+        # e.g. in the case of multi-speaker speech separation
+        self.use_frontend_for_all = bnmask > 2
+
+        if self.use_wpe:
+            if self.use_dnn_mask_for_wpe:
+                # Use DNN for power estimation
+                # (Not observed significant gains)
+                iterations = 1
+            else:
+                # Performing as conventional WPE, without DNN Estimator
+                iterations = 2
+
+            self.wpe = DNN_WPE(
+                wtype=wtype,
+                widim=idim,
+                wunits=wunits,
+                wprojs=wprojs,
+                wlayers=wlayers,
+                taps=taps,
+                delay=delay,
+                dropout_rate=wdropout_rate,
+                iterations=iterations,
+                use_dnn_mask=use_dnn_mask_for_wpe,
+            )
+        else:
+            self.wpe = None
+
+        if self.use_beamformer:
+            self.beamformer = DNN_Beamformer(
+                btype=btype,
+                bidim=idim,
+                bunits=bunits,
+                bprojs=bprojs,
+                blayers=blayers,
+                bnmask=bnmask,
+                dropout_rate=bdropout_rate,
+                badim=badim,
+                ref_channel=ref_channel,
+            )
+        else:
+            self.beamformer = None
+
+    def forward(
+        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
+    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
+        assert len(x) == len(ilens), (len(x), len(ilens))
+        # (B, T, F) or (B, T, C, F)
+        if x.dim() not in (3, 4):
+            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
+        if not torch.is_tensor(ilens):
+            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
+
+        mask = None
+        h = x
+        if h.dim() == 4:
+            if self.training:
+                choices = [(False, False)] if not self.use_frontend_for_all else []
+                if self.use_wpe:
+                    choices.append((True, False))
+
+                if self.use_beamformer:
+                    choices.append((False, True))
+
+                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
+
+            else:
+                use_wpe = self.use_wpe
+                use_beamformer = self.use_beamformer
+
+            # 1. WPE
+            if use_wpe:
+                # h: (B, T, C, F) -> h: (B, T, C, F)
+                h, ilens, mask = self.wpe(h, ilens)
+
+            # 2. Beamformer
+            if use_beamformer:
+                # h: (B, T, C, F) -> h: (B, T, F)
+                h, ilens, mask = self.beamformer(h, ilens)
+
+        return h, ilens, mask
+
+
+def frontend_for(args, idim):
+    return Frontend(
+        idim=idim,
+        # WPE options
+        use_wpe=args.use_wpe,
+        wtype=args.wtype,
+        wlayers=args.wlayers,
+        wunits=args.wunits,
+        wprojs=args.wprojs,
+        wdropout_rate=args.wdropout_rate,
+        taps=args.wpe_taps,
+        delay=args.wpe_delay,
+        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
+        # Beamformer options
+        use_beamformer=args.use_beamformer,
+        btype=args.btype,
+        blayers=args.blayers,
+        bunits=args.bunits,
+        bprojs=args.bprojs,
+        bnmask=args.bnmask,
+        badim=args.badim,
+        ref_channel=args.ref_channel,
+        bdropout_rate=args.bdropout_rate,
+    )
diff --git a/funasr/models/frontend/frontends_utils/mask_estimator.py b/funasr/models/frontend/frontends_utils/mask_estimator.py
new file mode 100644
index 0000000..53072bf
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/mask_estimator.py
@@ -0,0 +1,77 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torch_complex.tensor import ComplexTensor
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.rnn.encoders import RNN
+from funasr.modules.rnn.encoders import RNNP
+
+
+class MaskEstimator(torch.nn.Module):
+    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
+        super().__init__()
+        subsample = np.ones(layers + 1, dtype=np.int32)
+
+        typ = type.lstrip("vgg").rstrip("p")
+        if type[-1] == "p":
+            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
+        else:
+            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
+
+        self.type = type
+        self.nmask = nmask
+        self.linears = torch.nn.ModuleList(
+            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
+        )
+
+    def forward(
+        self, xs: ComplexTensor, ilens: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
+        """The forward function
+
+        Args:
+            xs: (B, F, C, T)
+            ilens: (B,)
+        Returns:
+            hs (torch.Tensor): The hidden vector (B, F, C, T)
+            masks: A tuple of the masks. (B, F, C, T)
+            ilens: (B,)
+        """
+        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
+        _, _, C, input_length = xs.size()
+        # (B, F, C, T) -> (B, C, T, F)
+        xs = xs.permute(0, 2, 3, 1)
+
+        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
+        xs = (xs.real**2 + xs.imag**2) ** 0.5
+        # xs: (B, C, T, F) -> xs: (B * C, T, F)
+        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
+        # ilens: (B,) -> ilens_: (B * C)
+        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
+
+        # xs: (B * C, T, F) -> xs: (B * C, T, D)
+        xs, _, _ = self.brnn(xs, ilens_)
+        # xs: (B * C, T, D) -> xs: (B, C, T, D)
+        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
+
+        masks = []
+        for linear in self.linears:
+            # xs: (B, C, T, D) -> mask:(B, C, T, F)
+            mask = linear(xs)
+
+            mask = torch.sigmoid(mask)
+            # Zero padding
+            mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
+
+            # (B, C, T, F) -> (B, F, C, T)
+            mask = mask.permute(0, 3, 1, 2)
+
+            # Take cares of multi gpu cases: If input_length > max(ilens)
+            if mask.size(-1) < input_length:
+                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
+            masks.append(mask)
+
+        return tuple(masks), ilens

--
Gitblit v1.9.1