游雁
2023-11-21 8b9c53cfc30b43129f66dad04d761b87f7dcc89f
funasr v2 setup
4个文件已修改
7个文件已添加
851 ■■■■■ 已修改文件
funasr/layers/complex_utils.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/layers/stft.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/mossformer_encoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/default.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/__init__.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/beamformer.py 84 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/dnn_beamformer.py 172 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/dnn_wpe.py 93 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/feature_transform.py 263 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/frontend.py 151 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/frontends_utils/mask_estimator.py 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/layers/complex_utils.py
@@ -9,7 +9,7 @@
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except:
    raise "Please install torch_complex firstly"
    print("Please install torch_complex firstly")
funasr/layers/stft.py
@@ -8,7 +8,7 @@
try:
    from torch_complex.tensor import ComplexTensor
except:
    raise "Please install torch_complex firstly"
    print("Please install torch_complex firstly")
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface
funasr/models/encoder/mossformer_encoder.py
@@ -4,7 +4,7 @@
try:
    from rotary_embedding_torch import RotaryEmbedding
except:
    raise "Please install rotary_embedding_torch by: \n pip install -U funasr[all]"
    print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
funasr/models/frontend/default.py
@@ -9,12 +9,12 @@
try:
    from torch_complex.tensor import ComplexTensor
except:
    raise "Please install torch_complex firstly"
    print("Please install torch_complex firstly")
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask
funasr/models/frontend/frontends_utils/__init__.py
New file
@@ -0,0 +1 @@
"""Initialize sub package."""
funasr/models/frontend/frontends_utils/beamformer.py
New file
@@ -0,0 +1,84 @@
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
def get_power_spectral_density_matrix(
    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
    """Return cross-channel power spectral density (PSD) matrix
    Args:
        xs (ComplexTensor): (..., F, C, T)
        mask (torch.Tensor): (..., F, C, T)
        normalization (bool):
        eps (float):
    Returns
        psd (ComplexTensor): (..., F, C, C)
    """
    # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
    psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
    # Averaging mask along C: (..., C, T) -> (..., T)
    mask = mask.mean(dim=-2)
    # Normalized mask along T: (..., T)
    if normalization:
        # If assuming the tensor is padded with zero, the summation along
        # the time axis is same regardless of the padding length.
        mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
    # psd: (..., T, C, C)
    psd = psd_Y * mask[..., None, None]
    # (..., T, C, C) -> (..., C, C)
    psd = psd.sum(dim=-3)
    return psd
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR(Minimum Variance Distortionless Response) vector:
        h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420
    Args:
        psd_s (ComplexTensor): (..., F, C, C)
        psd_n (ComplexTensor): (..., F, C, C)
        reference_vector (torch.Tensor): (..., C)
        eps (float):
    Returns:
        beamform_vector (ComplexTensor)r: (..., F, C)
    """
    # Add eps
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    psd_n += eps * eye
    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
def apply_beamforming_vector(
    beamform_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    # (..., C) x (..., C, T) -> (..., T)
    es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
    return es
funasr/models/frontend/frontends_utils/dnn_beamformer.py
New file
@@ -0,0 +1,172 @@
"""DNN beamformer module."""
from typing import Tuple
import torch
from torch.nn import functional as F
from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
from funasr.models.frontend.frontends_utils.beamformer import (
    get_power_spectral_density_matrix,  # noqa: H301
)
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer
    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783
    """
    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel
        self.nmask = bnmask
        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function
        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq
        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
        """
        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)
            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)
            return enhanced, ws
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)
        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)
                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)
                enhanced.append(enh)
                ws.append(w)
        return enhanced, ilens, mask_speech
class AttentionReference(torch.nn.Module):
    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)
    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function
        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float):
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
        # Calculate amplitude
        psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
funasr/models/frontend/frontends_utils/dnn_wpe.py
New file
@@ -0,0 +1,93 @@
from typing import Tuple
from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor
from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
from funasr.modules.nets_utils import make_pad_mask
class DNN_WPE(torch.nn.Module):
    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay
        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask
        self.inverse_power = True
        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )
    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function
        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector
        Args:
            data: (B, C, T, F)
            ilens: (B,)
        Returns:
            data: (B, C, T, F)
            ilens: (B,)
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None
        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask
            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)
            # enhanced: (..., C, T) -> (..., C, T)
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
funasr/models/frontend/frontends_utils/feature_transform.py
New file
@@ -0,0 +1,263 @@
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from funasr.modules.nets_utils import make_pad_mask
class FeatureTransform(torch.nn.Module):
    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn
        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None
        if self.apply_uttmvn is not None:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2
        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)
        return h, ilens
class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats
    The arguments is same as librosa.filters.mel
    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization).  Otherwise, leave all the triangles aiming for
            a peak value of 1.0
    """
    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = _mel_options
        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization
    Args:
        stats_file(str): npy file of 1-dim array or text file.
            From the _first element to
            the {(len(array) - 1) / 2}th element are treated as
            the sum of features,
            and the rest excluding the last elements are
            treated as the sum of the square value of features,
            and the last elements eqauls to the number of samples.
        std_floor(float):
    """
    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file
        stats = np.load(stats_file)
        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape
        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )
    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        return utterance_mvn(
            x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
        )
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization
    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B, T, D)
        norm_means:
        norm_vars:
        eps:
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]
    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]
    # Zero padding
    x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
def feature_transform_for(args, n_fft):
    return FeatureTransform(
        # Mel options,
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
funasr/models/frontend/frontends_utils/frontend.py
New file
@@ -0,0 +1,151 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module):
    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate=0.0,
    ):
        super().__init__()
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2
        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2
            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None
    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))
                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer
            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)
            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)
        return h, ilens, mask
def frontend_for(args, idim):
    return Frontend(
        idim=idim,
        # WPE options
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
        # Beamformer options
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
funasr/models/frontend/frontends_utils/mask_estimator.py
New file
@@ -0,0 +1,77 @@
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
from funasr.modules.rnn.encoders import RNNP
class MaskEstimator(torch.nn.Module):
    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
        super().__init__()
        subsample = np.ones(layers + 1, dtype=np.int32)
        typ = type.lstrip("vgg").rstrip("p")
        if type[-1] == "p":
            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
        else:
            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
        self.type = type
        self.nmask = nmask
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
        )
    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function
        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            hs (torch.Tensor): The hidden vector (B, F, C, T)
            masks: A tuple of the masks. (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)
        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2) ** 0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)
            mask = torch.sigmoid(mask)
            # Zero padding
            mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)
            # Take cares of multi gpu cases: If input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)
        return tuple(masks), ilens