From 8b9c53cfc30b43129f66dad04d761b87f7dcc89f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 21 十一月 2023 15:47:18 +0800
Subject: [PATCH] funasr v2 setup
---
funasr/layers/complex_utils.py | 2
funasr/models/frontend/frontends_utils/mask_estimator.py | 77 +++++
funasr/models/encoder/mossformer_encoder.py | 2
funasr/models/frontend/default.py | 4
funasr/models/frontend/frontends_utils/dnn_beamformer.py | 172 +++++++++++
funasr/models/frontend/frontends_utils/__init__.py | 1
funasr/models/frontend/frontends_utils/feature_transform.py | 263 +++++++++++++++++
funasr/models/frontend/frontends_utils/dnn_wpe.py | 93 ++++++
funasr/layers/stft.py | 2
funasr/models/frontend/frontends_utils/beamformer.py | 84 +++++
funasr/models/frontend/frontends_utils/frontend.py | 151 ++++++++++
11 files changed, 846 insertions(+), 5 deletions(-)
diff --git a/funasr/layers/complex_utils.py b/funasr/layers/complex_utils.py
index d6f7c6d..5d313c6 100644
--- a/funasr/layers/complex_utils.py
+++ b/funasr/layers/complex_utils.py
@@ -9,7 +9,7 @@
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
except:
- raise "Please install torch_complex firstly"
+ print("Please install torch_complex firstly")
diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index 67ebf7a..c71af8e 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -8,7 +8,7 @@
try:
from torch_complex.tensor import ComplexTensor
except:
- raise "Please install torch_complex firstly"
+ print("Please install torch_complex firstly")
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
from funasr.layers.inversible_interface import InversibleInterface
diff --git a/funasr/models/encoder/mossformer_encoder.py b/funasr/models/encoder/mossformer_encoder.py
index f7d9c47..6c092e7 100644
--- a/funasr/models/encoder/mossformer_encoder.py
+++ b/funasr/models/encoder/mossformer_encoder.py
@@ -4,7 +4,7 @@
try:
from rotary_embedding_torch import RotaryEmbedding
except:
- raise "Please install rotary_embedding_torch by: \n pip install -U funasr[all]"
+ print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")
from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 8d60e20..3435321 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -9,12 +9,12 @@
try:
from torch_complex.tensor import ComplexTensor
except:
- raise "Please install torch_complex firstly"
+ print("Please install torch_complex firstly")
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.modules.frontends.frontend import Frontend
+from funasr.models.frontend.frontends_utils.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.nets_utils import make_pad_mask
diff --git a/funasr/models/frontend/frontends_utils/__init__.py b/funasr/models/frontend/frontends_utils/__init__.py
new file mode 100644
index 0000000..b7f1773
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/__init__.py
@@ -0,0 +1 @@
+"""Initialize sub package."""
diff --git a/funasr/models/frontend/frontends_utils/beamformer.py b/funasr/models/frontend/frontends_utils/beamformer.py
new file mode 100644
index 0000000..f3eccee
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/beamformer.py
@@ -0,0 +1,84 @@
+import torch
+from torch_complex import functional as FC
+from torch_complex.tensor import ComplexTensor
+
+
+def get_power_spectral_density_matrix(
+ xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
+) -> ComplexTensor:
+ """Return cross-channel power spectral density (PSD) matrix
+
+ Args:
+ xs (ComplexTensor): (..., F, C, T)
+ mask (torch.Tensor): (..., F, C, T)
+ normalization (bool):
+ eps (float):
+ Returns:
+ psd (ComplexTensor): (..., F, C, C)
+
+ """
+ # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
+ psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
+
+ # Averaging mask along C: (..., C, T) -> (..., T)
+ mask = mask.mean(dim=-2)
+
+ # Normalized mask along T: (..., T)
+ if normalization:
+ # If assuming the tensor is padded with zero, the summation along
+ # the time axis is same regardless of the padding length.
+ mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
+
+ # psd: (..., T, C, C)
+ psd = psd_Y * mask[..., None, None]
+ # (..., T, C, C) -> (..., C, C)
+ psd = psd.sum(dim=-3)
+
+ return psd
+
+
+def get_mvdr_vector(
+ psd_s: ComplexTensor,
+ psd_n: ComplexTensor,
+ reference_vector: torch.Tensor,
+ eps: float = 1e-15,
+) -> ComplexTensor:
+ """Return the MVDR(Minimum Variance Distortionless Response) vector:
+
+ h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
+
+ Reference:
+ On optimal frequency-domain multichannel linear filtering
+ for noise reduction; M. Souden et al., 2010;
+ https://ieeexplore.ieee.org/document/5089420
+
+ Args:
+ psd_s (ComplexTensor): (..., F, C, C)
+ psd_n (ComplexTensor): (..., F, C, C)
+ reference_vector (torch.Tensor): (..., C)
+ eps (float):
+ Returns:
+ beamform_vector (ComplexTensor): (..., F, C)
+ """
+ # Add eps
+ C = psd_n.size(-1)
+ eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
+ shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
+ eye = eye.view(*shape)
+ psd_n += eps * eye
+
+ # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
+ numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
+ # ws: (..., C, C) / (...,) -> (..., C, C)
+ ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
+ # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
+ beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
+ return beamform_vector
+
+
+def apply_beamforming_vector(
+ beamform_vector: ComplexTensor, mix: ComplexTensor
+) -> ComplexTensor:
+ # (..., C) x (..., C, T) -> (..., T)
+ es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
+ return es
diff --git a/funasr/models/frontend/frontends_utils/dnn_beamformer.py b/funasr/models/frontend/frontends_utils/dnn_beamformer.py
new file mode 100644
index 0000000..05e241d
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/dnn_beamformer.py
@@ -0,0 +1,172 @@
+"""DNN beamformer module."""
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector
+from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector
+from funasr.models.frontend.frontends_utils.beamformer import (
+ get_power_spectral_density_matrix, # noqa: H301
+)
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
+from torch_complex.tensor import ComplexTensor
+
+
+class DNN_Beamformer(torch.nn.Module):
+ """DNN mask based Beamformer
+
+ Citation:
+ Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
+ https://arxiv.org/abs/1703.04783
+
+ """
+
+ def __init__(
+ self,
+ bidim,
+ btype="blstmp",
+ blayers=3,
+ bunits=300,
+ bprojs=320,
+ bnmask=2,
+ dropout_rate=0.0,
+ badim=320,
+ ref_channel: int = -1,
+ beamformer_type="mvdr",
+ ):
+ super().__init__()
+ self.mask = MaskEstimator(
+ btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
+ )
+ self.ref = AttentionReference(bidim, badim)
+ self.ref_channel = ref_channel
+
+ self.nmask = bnmask
+
+ if beamformer_type != "mvdr":
+ raise ValueError(
+ "Not supporting beamformer_type={}".format(beamformer_type)
+ )
+ self.beamformer_type = beamformer_type
+
+ def forward(
+ self, data: ComplexTensor, ilens: torch.LongTensor
+ ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
+ """The forward function
+
+ Notation:
+ B: Batch
+ C: Channel
+ T: Time or Sequence length
+ F: Freq
+
+ Args:
+ data (ComplexTensor): (B, T, C, F)
+ ilens (torch.Tensor): (B,)
+ Returns:
+ enhanced (ComplexTensor): (B, T, F)
+ ilens (torch.Tensor): (B,)
+
+ """
+
+ def apply_beamforming(data, ilens, psd_speech, psd_noise):
+ # u: (B, C)
+ if self.ref_channel < 0:
+ u, _ = self.ref(psd_speech, ilens)
+ else:
+ # (optional) Create onehot vector for fixed reference microphone
+ u = torch.zeros(
+ *(data.size()[:-3] + (data.size(-2),)), device=data.device
+ )
+ u[..., self.ref_channel].fill_(1)
+
+ ws = get_mvdr_vector(psd_speech, psd_noise, u)
+ enhanced = apply_beamforming_vector(ws, data)
+
+ return enhanced, ws
+
+ # data (B, T, C, F) -> (B, F, C, T)
+ data = data.permute(0, 3, 2, 1)
+
+ # mask: (B, F, C, T)
+ masks, _ = self.mask(data, ilens)
+ assert self.nmask == len(masks)
+
+ if self.nmask == 2: # (mask_speech, mask_noise)
+ mask_speech, mask_noise = masks
+
+ psd_speech = get_power_spectral_density_matrix(data, mask_speech)
+ psd_noise = get_power_spectral_density_matrix(data, mask_noise)
+
+ enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
+
+ # (..., F, T) -> (..., T, F)
+ enhanced = enhanced.transpose(-1, -2)
+ mask_speech = mask_speech.transpose(-1, -3)
+ else: # multi-speaker case: (mask_speech1, ..., mask_noise)
+ mask_speech = list(masks[:-1])
+ mask_noise = masks[-1]
+
+ psd_speeches = [
+ get_power_spectral_density_matrix(data, mask) for mask in mask_speech
+ ]
+ psd_noise = get_power_spectral_density_matrix(data, mask_noise)
+
+ enhanced = []
+ ws = []
+ for i in range(self.nmask - 1):
+ psd_speech = psd_speeches.pop(i)
+ # treat all other speakers' psd_speech as noises
+ enh, w = apply_beamforming(
+ data, ilens, psd_speech, sum(psd_speeches) + psd_noise
+ )
+ psd_speeches.insert(i, psd_speech)
+
+ # (..., F, T) -> (..., T, F)
+ enh = enh.transpose(-1, -2)
+ mask_speech[i] = mask_speech[i].transpose(-1, -3)
+
+ enhanced.append(enh)
+ ws.append(w)
+
+ return enhanced, ilens, mask_speech
+
+
+class AttentionReference(torch.nn.Module):
+ def __init__(self, bidim, att_dim):
+ super().__init__()
+ self.mlp_psd = torch.nn.Linear(bidim, att_dim)
+ self.gvec = torch.nn.Linear(att_dim, 1)
+
+ def forward(
+ self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ """The forward function
+
+ Args:
+ psd_in (ComplexTensor): (B, F, C, C)
+ ilens (torch.Tensor): (B,)
+ scaling (float):
+ Returns:
+ u (torch.Tensor): (B, C)
+ ilens (torch.Tensor): (B,)
+ """
+ B, _, C = psd_in.size()[:3]
+ assert psd_in.size(2) == psd_in.size(3), psd_in.size()
+ # psd_in: (B, F, C, C)
+ psd = psd_in.masked_fill(
+ torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
+ )
+ # psd: (B, F, C, C) -> (B, C, F)
+ psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
+
+ # Calculate amplitude
+ psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
+
+ # (B, C, F) -> (B, C, F2)
+ mlp_psd = self.mlp_psd(psd_feat)
+ # (B, C, F2) -> (B, C, 1) -> (B, C)
+ e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
+ u = F.softmax(scaling * e, dim=-1)
+ return u, ilens
diff --git a/funasr/models/frontend/frontends_utils/dnn_wpe.py b/funasr/models/frontend/frontends_utils/dnn_wpe.py
new file mode 100644
index 0000000..ffe12a3
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/dnn_wpe.py
@@ -0,0 +1,93 @@
+from typing import Tuple
+
+from pytorch_wpe import wpe_one_iteration
+import torch
+from torch_complex.tensor import ComplexTensor
+
+from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator
+from funasr.modules.nets_utils import make_pad_mask
+
+
+class DNN_WPE(torch.nn.Module):
+ def __init__(
+ self,
+ wtype: str = "blstmp",
+ widim: int = 257,
+ wlayers: int = 3,
+ wunits: int = 300,
+ wprojs: int = 320,
+ dropout_rate: float = 0.0,
+ taps: int = 5,
+ delay: int = 3,
+ use_dnn_mask: bool = True,
+ iterations: int = 1,
+ normalization: bool = False,
+ ):
+ super().__init__()
+ self.iterations = iterations
+ self.taps = taps
+ self.delay = delay
+
+ self.normalization = normalization
+ self.use_dnn_mask = use_dnn_mask
+
+ self.inverse_power = True
+
+ if self.use_dnn_mask:
+ self.mask_est = MaskEstimator(
+ wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
+ )
+
+ def forward(
+ self, data: ComplexTensor, ilens: torch.LongTensor
+ ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
+ """The forward function
+
+ Notation:
+ B: Batch
+ C: Channel
+ T: Time or Sequence length
+ F: Freq or Some dimension of the feature vector
+
+ Args:
+ data: (B, C, T, F)
+ ilens: (B,)
+ Returns:
+ data: (B, C, T, F)
+ ilens: (B,)
+ """
+ # (B, T, C, F) -> (B, F, C, T)
+ enhanced = data = data.permute(0, 3, 2, 1)
+ mask = None
+
+ for i in range(self.iterations):
+ # Calculate power: (..., C, T)
+ power = enhanced.real**2 + enhanced.imag**2
+ if i == 0 and self.use_dnn_mask:
+ # mask: (B, F, C, T)
+ (mask,), _ = self.mask_est(enhanced, ilens)
+ if self.normalization:
+ # Normalize along T
+ mask = mask / mask.sum(dim=-1)[..., None]
+ # (..., C, T) * (..., C, T) -> (..., C, T)
+ power = power * mask
+
+ # Averaging along the channel axis: (..., C, T) -> (..., T)
+ power = power.mean(dim=-2)
+
+ # enhanced: (..., C, T) -> (..., C, T)
+ enhanced = wpe_one_iteration(
+ data.contiguous(),
+ power,
+ taps=self.taps,
+ delay=self.delay,
+ inverse_power=self.inverse_power,
+ )
+
+ enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
+
+ # (B, F, C, T) -> (B, T, C, F)
+ enhanced = enhanced.permute(0, 3, 2, 1)
+ if mask is not None:
+ mask = mask.transpose(-1, -3)
+ return enhanced, ilens, mask
diff --git a/funasr/models/frontend/frontends_utils/feature_transform.py b/funasr/models/frontend/frontends_utils/feature_transform.py
new file mode 100644
index 0000000..353dca1
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/feature_transform.py
@@ -0,0 +1,263 @@
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import librosa
+import numpy as np
+import torch
+from torch_complex.tensor import ComplexTensor
+
+from funasr.modules.nets_utils import make_pad_mask
+
+
+class FeatureTransform(torch.nn.Module):
+ def __init__(
+ self,
+ # Mel options,
+ fs: int = 16000,
+ n_fft: int = 512,
+ n_mels: int = 80,
+ fmin: float = 0.0,
+ fmax: float = None,
+ # Normalization
+ stats_file: str = None,
+ apply_uttmvn: bool = True,
+ uttmvn_norm_means: bool = True,
+ uttmvn_norm_vars: bool = False,
+ ):
+ super().__init__()
+ self.apply_uttmvn = apply_uttmvn
+
+ self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+ self.stats_file = stats_file
+ if stats_file is not None:
+ self.global_mvn = GlobalMVN(stats_file)
+ else:
+ self.global_mvn = None
+
+ if self.apply_uttmvn is not None:
+ self.uttmvn = UtteranceMVN(
+ norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
+ )
+ else:
+ self.uttmvn = None
+
+ def forward(
+ self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ # (B, T, F) or (B, T, C, F)
+ if x.dim() not in (3, 4):
+ raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
+ if not torch.is_tensor(ilens):
+ ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
+
+ if x.dim() == 4:
+ # h: (B, T, C, F) -> h: (B, T, F)
+ if self.training:
+ # Select 1ch randomly
+ ch = np.random.randint(x.size(2))
+ h = x[:, :, ch, :]
+ else:
+ # Use the first channel
+ h = x[:, :, 0, :]
+ else:
+ h = x
+
+ # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
+ h = h.real**2 + h.imag**2
+
+ h, _ = self.logmel(h, ilens)
+ if self.stats_file is not None:
+ h, _ = self.global_mvn(h, ilens)
+ if self.apply_uttmvn:
+ h, _ = self.uttmvn(h, ilens)
+
+ return h, ilens
+
+
+class LogMel(torch.nn.Module):
+ """Convert STFT to fbank feats
+
+ The arguments are the same as librosa.filters.mel
+
+ Args:
+ fs: number > 0 [scalar] sampling rate of the incoming signal
+ n_fft: int > 0 [scalar] number of FFT components
+ n_mels: int > 0 [scalar] number of Mel bands to generate
+ fmin: float >= 0 [scalar] lowest frequency (in Hz)
+ fmax: float >= 0 [scalar] highest frequency (in Hz).
+ If `None`, use `fmax = fs / 2.0`
+ htk: use HTK formula instead of Slaney
+ norm: {None, 1, np.inf} [scalar]
+ if 1, divide the triangular mel weights by the width of the mel band
+ (area normalization). Otherwise, leave all the triangles aiming for
+ a peak value of 1.0
+
+ """
+
+ def __init__(
+ self,
+ fs: int = 16000,
+ n_fft: int = 512,
+ n_mels: int = 80,
+ fmin: float = 0.0,
+ fmax: float = None,
+ htk: bool = False,
+ norm=1,
+ ):
+ super().__init__()
+
+ _mel_options = dict(
+ sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
+ )
+ self.mel_options = _mel_options
+
+ # Note(kamo): The mel matrix of librosa is different from kaldi.
+ melmat = librosa.filters.mel(**_mel_options)
+ # melmat: (D2, D1) -> (D1, D2)
+ self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
+
+ def extra_repr(self):
+ return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
+
+ def forward(
+ self, feat: torch.Tensor, ilens: torch.LongTensor
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
+ mel_feat = torch.matmul(feat, self.melmat)
+
+ logmel_feat = (mel_feat + 1e-20).log()
+ # Zero padding
+ logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
+ return logmel_feat, ilens
+
+
+class GlobalMVN(torch.nn.Module):
+ """Apply global mean and variance normalization
+
+ Args:
+ stats_file(str): npy file of 1-dim array or text file.
+ From the first element to
+ the {(len(array) - 1) / 2}th element are treated as
+ the sum of features,
+ and the rest excluding the last elements are
+ treated as the sum of the square value of features,
+ and the last element equals the number of samples.
+ eps(float): floor value applied to the standard deviation
+ """
+
+ def __init__(
+ self,
+ stats_file: str,
+ norm_means: bool = True,
+ norm_vars: bool = True,
+ eps: float = 1.0e-20,
+ ):
+ super().__init__()
+ self.norm_means = norm_means
+ self.norm_vars = norm_vars
+
+ self.stats_file = stats_file
+ stats = np.load(stats_file)
+
+ stats = stats.astype(float)
+ assert (len(stats) - 1) % 2 == 0, stats.shape
+
+ count = stats.flatten()[-1]
+ mean = stats[: (len(stats) - 1) // 2] / count
+ var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
+ std = np.maximum(np.sqrt(var), eps)
+
+ self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
+ self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
+
+ def extra_repr(self):
+ return (
+ f"stats_file={self.stats_file}, "
+ f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
+ )
+
+ def forward(
+ self, x: torch.Tensor, ilens: torch.LongTensor
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ # feat: (B, T, D)
+ if self.norm_means:
+ x += self.bias.type_as(x)
+ x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
+
+ if self.norm_vars:
+ x *= self.scale.type_as(x)
+ return x, ilens
+
+
+class UtteranceMVN(torch.nn.Module):
+ def __init__(
+ self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
+ ):
+ super().__init__()
+ self.norm_means = norm_means
+ self.norm_vars = norm_vars
+ self.eps = eps
+
+ def extra_repr(self):
+ return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
+
+ def forward(
+ self, x: torch.Tensor, ilens: torch.LongTensor
+ ) -> Tuple[torch.Tensor, torch.LongTensor]:
+ return utterance_mvn(
+ x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
+ )
+
+
+def utterance_mvn(
+ x: torch.Tensor,
+ ilens: torch.LongTensor,
+ norm_means: bool = True,
+ norm_vars: bool = False,
+ eps: float = 1.0e-20,
+) -> Tuple[torch.Tensor, torch.LongTensor]:
+ """Apply utterance mean and variance normalization
+
+ Args:
+ x: (B, T, D), assumed zero padded
+ ilens: (B,)
+ norm_means:
+ norm_vars:
+ eps:
+
+ """
+ ilens_ = ilens.type_as(x)
+ # mean: (B, D)
+ mean = x.sum(dim=1) / ilens_[:, None]
+
+ if norm_means:
+ x -= mean[:, None, :]
+ x_ = x
+ else:
+ x_ = x - mean[:, None, :]
+
+ # Zero padding
+ x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
+ if norm_vars:
+ var = x_.pow(2).sum(dim=1) / ilens_[:, None]
+ var = torch.clamp(var, min=eps)
+ x /= var.sqrt()[:, None, :]
+ x_ = x
+ return x_, ilens
+
+
+def feature_transform_for(args, n_fft):
+ return FeatureTransform(
+ # Mel options,
+ fs=args.fbank_fs,
+ n_fft=n_fft,
+ n_mels=args.n_mels,
+ fmin=args.fbank_fmin,
+ fmax=args.fbank_fmax,
+ # Normalization
+ stats_file=args.stats_file,
+ apply_uttmvn=args.apply_uttmvn,
+ uttmvn_norm_means=args.uttmvn_norm_means,
+ uttmvn_norm_vars=args.uttmvn_norm_vars,
+ )
diff --git a/funasr/models/frontend/frontends_utils/frontend.py b/funasr/models/frontend/frontends_utils/frontend.py
new file mode 100644
index 0000000..bf266e9
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/frontend.py
@@ -0,0 +1,151 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy
+import torch
+import torch.nn as nn
+from torch_complex.tensor import ComplexTensor
+
+from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer
+from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE
+
+
+class Frontend(nn.Module):
+ def __init__(
+ self,
+ idim: int,
+ # WPE options
+ use_wpe: bool = False,
+ wtype: str = "blstmp",
+ wlayers: int = 3,
+ wunits: int = 300,
+ wprojs: int = 320,
+ wdropout_rate: float = 0.0,
+ taps: int = 5,
+ delay: int = 3,
+ use_dnn_mask_for_wpe: bool = True,
+ # Beamformer options
+ use_beamformer: bool = False,
+ btype: str = "blstmp",
+ blayers: int = 3,
+ bunits: int = 300,
+ bprojs: int = 320,
+ bnmask: int = 2,
+ badim: int = 320,
+ ref_channel: int = -1,
+ bdropout_rate=0.0,
+ ):
+ super().__init__()
+
+ self.use_beamformer = use_beamformer
+ self.use_wpe = use_wpe
+ self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
+ # use frontend for all the data,
+ # e.g. in the case of multi-speaker speech separation
+ self.use_frontend_for_all = bnmask > 2
+
+ if self.use_wpe:
+ if self.use_dnn_mask_for_wpe:
+ # Use DNN for power estimation
+ # (Not observed significant gains)
+ iterations = 1
+ else:
+ # Performing as conventional WPE, without DNN Estimator
+ iterations = 2
+
+ self.wpe = DNN_WPE(
+ wtype=wtype,
+ widim=idim,
+ wunits=wunits,
+ wprojs=wprojs,
+ wlayers=wlayers,
+ taps=taps,
+ delay=delay,
+ dropout_rate=wdropout_rate,
+ iterations=iterations,
+ use_dnn_mask=use_dnn_mask_for_wpe,
+ )
+ else:
+ self.wpe = None
+
+ if self.use_beamformer:
+ self.beamformer = DNN_Beamformer(
+ btype=btype,
+ bidim=idim,
+ bunits=bunits,
+ bprojs=bprojs,
+ blayers=blayers,
+ bnmask=bnmask,
+ dropout_rate=bdropout_rate,
+ badim=badim,
+ ref_channel=ref_channel,
+ )
+ else:
+ self.beamformer = None
+
+ def forward(
+ self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
+ ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
+ assert len(x) == len(ilens), (len(x), len(ilens))
+ # (B, T, F) or (B, T, C, F)
+ if x.dim() not in (3, 4):
+ raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
+ if not torch.is_tensor(ilens):
+ ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
+
+ mask = None
+ h = x
+ if h.dim() == 4:
+ if self.training:
+ choices = [(False, False)] if not self.use_frontend_for_all else []
+ if self.use_wpe:
+ choices.append((True, False))
+
+ if self.use_beamformer:
+ choices.append((False, True))
+
+ use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
+
+ else:
+ use_wpe = self.use_wpe
+ use_beamformer = self.use_beamformer
+
+ # 1. WPE
+ if use_wpe:
+ # h: (B, T, C, F) -> h: (B, T, C, F)
+ h, ilens, mask = self.wpe(h, ilens)
+
+ # 2. Beamformer
+ if use_beamformer:
+ # h: (B, T, C, F) -> h: (B, T, F)
+ h, ilens, mask = self.beamformer(h, ilens)
+
+ return h, ilens, mask
+
+
+def frontend_for(args, idim):
+ return Frontend(
+ idim=idim,
+ # WPE options
+ use_wpe=args.use_wpe,
+ wtype=args.wtype,
+ wlayers=args.wlayers,
+ wunits=args.wunits,
+ wprojs=args.wprojs,
+ wdropout_rate=args.wdropout_rate,
+ taps=args.wpe_taps,
+ delay=args.wpe_delay,
+ use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
+ # Beamformer options
+ use_beamformer=args.use_beamformer,
+ btype=args.btype,
+ blayers=args.blayers,
+ bunits=args.bunits,
+ bprojs=args.bprojs,
+ bnmask=args.bnmask,
+ badim=args.badim,
+ ref_channel=args.ref_channel,
+ bdropout_rate=args.bdropout_rate,
+ )
diff --git a/funasr/models/frontend/frontends_utils/mask_estimator.py b/funasr/models/frontend/frontends_utils/mask_estimator.py
new file mode 100644
index 0000000..53072bf
--- /dev/null
+++ b/funasr/models/frontend/frontends_utils/mask_estimator.py
@@ -0,0 +1,77 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torch_complex.tensor import ComplexTensor
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.rnn.encoders import RNN
+from funasr.modules.rnn.encoders import RNNP
+
+
+class MaskEstimator(torch.nn.Module):
+ def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
+ super().__init__()
+ subsample = np.ones(layers + 1, dtype=np.int32)
+
+ typ = type.lstrip("vgg").rstrip("p")
+ if type[-1] == "p":
+ self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
+ else:
+ self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
+
+ self.type = type
+ self.nmask = nmask
+ self.linears = torch.nn.ModuleList(
+ [torch.nn.Linear(projs, idim) for _ in range(nmask)]
+ )
+
+ def forward(
+ self, xs: ComplexTensor, ilens: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
+ """The forward function
+
+ Args:
+ xs: (B, F, C, T)
+ ilens: (B,)
+ Returns:
+ hs (torch.Tensor): The hidden vector (B, F, C, T)
+ masks: A tuple of the masks. (B, F, C, T)
+ ilens: (B,)
+ """
+ assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
+ _, _, C, input_length = xs.size()
+ # (B, F, C, T) -> (B, C, T, F)
+ xs = xs.permute(0, 2, 3, 1)
+
+ # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
+ xs = (xs.real**2 + xs.imag**2) ** 0.5
+ # xs: (B, C, T, F) -> xs: (B * C, T, F)
+ xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
+ # ilens: (B,) -> ilens_: (B * C)
+ ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
+
+ # xs: (B * C, T, F) -> xs: (B * C, T, D)
+ xs, _, _ = self.brnn(xs, ilens_)
+ # xs: (B * C, T, D) -> xs: (B, C, T, D)
+ xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
+
+ masks = []
+ for linear in self.linears:
+ # xs: (B, C, T, D) -> mask:(B, C, T, F)
+ mask = linear(xs)
+
+ mask = torch.sigmoid(mask)
+ # Zero padding
+ mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
+
+ # (B, C, T, F) -> (B, F, C, T)
+ mask = mask.permute(0, 3, 1, 2)
+
+ # Take cares of multi gpu cases: If input_length > max(ilens)
+ if mask.size(-1) < input_length:
+ mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
+ masks.append(mask)
+
+ return tuple(masks), ilens
--
Gitblit v1.9.1