| | |
| | | from torch_complex import functional as FC |
| | | from torch_complex.tensor import ComplexTensor |
| | | except: |
| | | raise "Please install torch_complex firstly" |
| | | print("Please install torch_complex firstly") |
| | | |
| | | |
| | | |
| | |
try:
    from torch_complex.tensor import ComplexTensor
except ImportError:
    # BUG FIX: the original did `raise "..."` (invalid in Python 3 — raised
    # objects must derive from BaseException, so this was a TypeError) and the
    # print() after it was unreachable. Raise a proper ImportError instead.
    raise ImportError("Please install torch_complex firstly")
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.layers.complex_utils import is_complex |
| | | from funasr.layers.inversible_interface import InversibleInterface |
| | |
try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    # BUG FIX: `raise "..."` is invalid in Python 3 (exceptions must derive
    # from BaseException) and the print() after it was unreachable.
    raise ImportError(
        "Please install rotary_embedding_torch by: \n pip install -U funasr[all]"
    )
| | | from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm |
| | | from funasr.modules.embedding import ScaledSinuEmbedding |
| | | from funasr.modules.mossformer import FLASH_ShareA_FFConvM |
| | |
try:
    from torch_complex.tensor import ComplexTensor
except ImportError:
    # BUG FIX: the original did `raise "..."` (invalid in Python 3 — raised
    # objects must derive from BaseException) followed by unreachable print().
    raise ImportError("Please install torch_complex firstly")
| | | |
| | | from funasr.layers.log_mel import LogMel |
| | | from funasr.layers.stft import Stft |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.modules.frontends.frontend import Frontend |
| | | from funasr.models.frontend.frontends_utils.frontend import Frontend |
| | | from funasr.utils.get_default_kwargs import get_default_kwargs |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| New file |
| | |
| | | """Initialize sub package.""" |
| New file |
| | |
| | | import torch |
| | | from torch_complex import functional as FC |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | |
def get_power_spectral_density_matrix(
    xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
    """Compute the cross-channel power spectral density (PSD) matrix.

    Args:
        xs (ComplexTensor): Observed spectrum, (..., F, C, T)
        mask (torch.Tensor): Time-frequency mask, (..., F, C, T)
        normalization (bool): If True, normalize the mask along the time axis.
        eps (float): Numerical floor for the mask normalization.
    Returns:
        psd (ComplexTensor): (..., F, C, C)
    """
    # Per-frame channel outer product:
    # (..., C_1, T) x (..., C_2, T) -> (..., T, C_1, C_2)
    outer = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])

    # Collapse the channel axis of the mask: (..., C, T) -> (..., T)
    avg_mask = mask.mean(dim=-2)

    if normalization:
        # If the tensor is zero padded, the time-axis sum is the same
        # regardless of the padding length.
        avg_mask = avg_mask / (avg_mask.sum(dim=-1, keepdim=True) + eps)

    # Mask-weight every frame's outer product, then sum over time:
    # (..., T, C, C) -> (..., C, C)
    return (outer * avg_mask[..., None, None]).sum(dim=-3)
| | | |
| | | |
def get_mvdr_vector(
    psd_s: ComplexTensor,
    psd_n: ComplexTensor,
    reference_vector: torch.Tensor,
    eps: float = 1e-15,
) -> ComplexTensor:
    """Return the MVDR (Minimum Variance Distortionless Response) vector:

    h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u

    Reference:
        On optimal frequency-domain multichannel linear filtering
        for noise reduction; M. Souden et al., 2010;
        https://ieeexplore.ieee.org/document/5089420

    Args:
        psd_s (ComplexTensor): Speech PSD, (..., F, C, C)
        psd_n (ComplexTensor): Noise PSD, (..., F, C, C)
        reference_vector (torch.Tensor): Reference-channel weights, (..., C)
        eps (float): Diagonal-loading and trace-regularization constant.
    Returns:
        beamform_vector (ComplexTensor): (..., F, C)
    """
    # Diagonal loading keeps psd_n invertible.
    C = psd_n.size(-1)
    eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
    shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
    eye = eye.view(*shape)
    # BUG FIX: use out-of-place addition. The original `psd_n += eps * eye`
    # silently mutated the caller's noise-PSD tensor, which is reused by
    # DNN_Beamformer in the multi-speaker loop.
    psd_n = psd_n + eps * eye

    # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
    numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
    # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
    beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
    return beamform_vector
| | | |
| | | |
def apply_beamforming_vector(
    beamform_vector: ComplexTensor, mix: ComplexTensor
) -> ComplexTensor:
    """Apply a beamforming filter to a multi-channel mixture.

    Contracts the channel axis: (..., C) x (..., C, T) -> (..., T).
    """
    enhanced = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
    return enhanced
| New file |
| | |
| | | """DNN beamformer module.""" |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | from torch.nn import functional as F |
| | | |
| | | from funasr.models.frontend.frontends_utils.beamformer import apply_beamforming_vector |
| | | from funasr.models.frontend.frontends_utils.beamformer import get_mvdr_vector |
| | | from funasr.models.frontend.frontends_utils.beamformer import ( |
| | | get_power_spectral_density_matrix, # noqa: H301 |
| | | ) |
| | | from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | |
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer

    Estimates time-frequency masks with a neural MaskEstimator, builds
    speech/noise PSD matrices from them, and applies an MVDR beamformer.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783

    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        """Initialize the beamformer.

        Args:
            bidim: Input feature (frequency-bin) dimension.
            btype: RNN type string for the mask estimator (e.g. "blstmp").
            blayers: Number of RNN layers in the mask estimator.
            bunits: RNN hidden units.
            bprojs: RNN projection size.
            bnmask: Number of masks to estimate; 2 means (speech, noise),
                >2 selects the multi-speaker branch of ``forward``.
            dropout_rate: Dropout rate for the mask estimator.
            badim: Attention dimension of the reference-channel estimator.
            ref_channel: Fixed reference microphone index; a negative value
                enables attention-based reference estimation instead.
            beamformer_type: Only "mvdr" is supported.

        Raises:
            ValueError: If ``beamformer_type`` is not "mvdr".
        """
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        # Attention-based reference selector, used only when ref_channel < 0.
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F), or a list of (B, T, F)
                tensors (one per speaker) when nmask > 2
            ilens (torch.Tensor): (B,)
            mask_speech: estimated speech mask(s), transposed to (B, T, C, F)

        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # Reference-channel weights u: (B, C)
            if self.ref_channel < 0:
                # Estimate u from the speech PSD via attention.
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                # Beamform speaker i; pop so the remaining speakers' PSDs
                # can be summed into the noise PSD below, then restore it.
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
| | | |
| | | |
class AttentionReference(torch.nn.Module):
    """Attention-based estimator of the reference-microphone weights."""

    def __init__(self, bidim, att_dim):
        super().__init__()
        # Projects per-channel PSD amplitude features into attention space.
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        # Scores each channel with a single attention logit.
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Estimate soft reference-channel weights.

        Args:
            psd_in (ComplexTensor): PSD matrices, (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): Temperature applied before the softmax.
        Returns:
            u (torch.Tensor): Channel weights summing to one, (B, C)
            ilens (torch.Tensor): (B,)
        """
        channels = psd_in.size(2)
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # Drop the diagonal (auto-correlation) entries: (B, F, C, C)
        diagonal = torch.eye(channels, dtype=torch.bool, device=psd_in.device)
        off_diag = psd_in.masked_fill(diagonal, 0)
        # Average the cross terms per channel, channels first: (B, C, F)
        averaged = (off_diag.sum(dim=-1) / (channels - 1)).transpose(-1, -2)

        # Amplitude of the averaged cross-PSD.
        amplitude = (averaged.real**2 + averaged.imag**2) ** 0.5

        # (B, C, F) -> (B, C, F2) -> (B, C, 1) -> (B, C)
        logits = self.gvec(torch.tanh(self.mlp_psd(amplitude))).squeeze(-1)
        u = F.softmax(scaling * logits, dim=-1)
        return u, ilens
| New file |
| | |
| | | from typing import Tuple |
| | | |
| | | from pytorch_wpe import wpe_one_iteration |
| | | import torch |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | from funasr.models.frontend.frontends_utils.mask_estimator import MaskEstimator |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | |
class DNN_WPE(torch.nn.Module):
    """WPE (weighted prediction error) dereverberation front-end.

    Wraps ``pytorch_wpe.wpe_one_iteration``. The per-frame power that WPE
    needs is either weighted by a DNN-estimated mask (``use_dnn_mask=True``)
    or refined over multiple plain iterations.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        # Number of filter taps and prediction delay of the WPE filter.
        self.taps = taps
        self.delay = delay

        # Whether to normalize the estimated mask along the time axis.
        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True

        if self.use_dnn_mask:
            # Single-mask estimator used to weight the per-frame power.
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)
        Returns:
            enhanced: dereverberated spectrum, (B, T, C, F)
            ilens: (B,)
            mask: estimated mask transposed to (B, T, C, F), or None when
                ``use_dnn_mask`` is False
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T); estimated only on the first iteration
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE: each iteration filters the ORIGINAL observation `data`,
            # only the power estimate is refined.
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero out frames beyond each utterance length.
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
| New file |
| | |
| | | from typing import List |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import librosa |
| | | import numpy as np |
| | | import torch |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | |
class FeatureTransform(torch.nn.Module):
    """STFT-domain feature pipeline: log-mel fbank plus optional global and
    per-utterance mean/variance normalization.

    Args:
        fs: Sampling rate used for the mel filterbank.
        n_fft: FFT size.
        n_mels: Number of mel bins.
        fmin: Lowest mel-filter frequency (Hz).
        fmax: Highest mel-filter frequency (Hz); None means fs / 2.
        stats_file: npy file of accumulated statistics for GlobalMVN;
            None disables global normalization.
        apply_uttmvn: Whether to apply utterance-level MVN.
        uttmvn_norm_means: Normalize means in utterance MVN.
        uttmvn_norm_vars: Normalize variances in utterance MVN.
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # BUG FIX: `apply_uttmvn` is a bool, so the original check
        # `if self.apply_uttmvn is not None:` was always true and built the
        # UtteranceMVN module even when it was disabled. Test truthiness,
        # matching the `if self.apply_uttmvn:` check in forward().
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Convert a complex spectrum to normalized log-mel features.

        Args:
            x: Complex spectrum, (B, T, F) or multi-channel (B, T, C, F).
            ilens: Valid lengths per utterance, (B,).
        Returns:
            Log-mel features (B, T, n_mels) and the lengths.
        Raises:
            ValueError: If ``x`` is not 3- or 4-dimensional.
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # Power spectrum: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
| | | |
| | | |
class LogMel(torch.nn.Module):
    """Convert STFT power spectra to log-mel filterbank features.

    The constructor arguments mirror ``librosa.filters.mel``.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel
            band (area normalization). Otherwise, leave all the triangles
            aiming for a peak value of 1.0

    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        # Stored transposed so forward() can right-multiply: (D1, D2).
        melmat = librosa.filters.mel(**mel_options)
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Project (B, T, D1) power features to (B, T, D2) log-mel features."""
        logmel_feat = torch.log(torch.matmul(feat, self.melmat) + 1e-20)
        # Zero out frames beyond each utterance length.
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
| | | |
| | | |
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file(str): npy file of a 1-dim array. From the first element to
            the {(len(array) - 1) / 2}th element are treated as
            the sum of features, and the rest excluding the last element are
            treated as the sum of the square value of features,
            and the last element equals the number of samples.
        norm_means(bool): Subtract the global mean.
        norm_vars(bool): Divide by the global standard deviation.
        eps(float): Floor for the standard deviation.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        # Sanity check: [sums..., sums of squares..., count]
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        # Stored so normalization is a fused add/mul: (x - mean) / std.
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize (B, T, D) features with the global statistics.

        BUG FIX: the original used in-place `+=`/`*=` (mutating the caller's
        tensor) and discarded the result of the out-of-place `masked_fill`,
        which made the padding re-zeroing a silent no-op. Both are replaced
        with out-of-place ops whose results are kept.
        """
        # feat: (B, T, D)
        if self.norm_means:
            x = x + self.bias.type_as(x)
            # Re-zero padded frames after the mean shift.
            x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x = x * self.scale.type_as(x)
        return x, ilens
| | | |
| | | |
class UtteranceMVN(torch.nn.Module):
    """Module wrapper around :func:`utterance_mvn`.

    Args:
        norm_means: Subtract the per-utterance mean.
        norm_vars: Divide by the per-utterance standard deviation.
        eps: Variance floor.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply utterance-level MVN to zero-padded (B, T, D) features."""
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )
| | | |
| | | |
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) valid lengths per utterance
        norm_means: subtract the per-utterance mean
        norm_vars: divide by the per-utterance standard deviation
        eps: variance floor

    Returns:
        Normalized features (B, T, D) and the unchanged ilens.
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D) — valid because padded frames are zero.
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        # In-place on purpose: x itself is normalized and returned.
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding
    # BUG FIX: the original called the out-of-place `masked_fill` and
    # discarded its result, so padded frames were never re-zeroed and the
    # variance below included garbage from padding. Use the in-place variant.
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        # NOTE(review): when norm_vars and not norm_means, the variance is
        # computed from the mean-subtracted x_ but applied to the raw x —
        # confirm this asymmetry is intended.
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
| | | |
| | | |
def feature_transform_for(args, n_fft):
    """Build a :class:`FeatureTransform` from an argparse-style namespace.

    Args:
        args: Namespace carrying the fbank and MVN options
            (fbank_fs, n_mels, fbank_fmin, fbank_fmax, stats_file,
            apply_uttmvn, uttmvn_norm_means, uttmvn_norm_vars).
        n_fft: FFT size forwarded to the log-mel front-end.
    """
    mel_opts = dict(
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
    )
    mvn_opts = dict(
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**mel_opts, **mvn_opts)
| New file |
| | |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import numpy |
| | | import torch |
| | | import torch.nn as nn |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | from funasr.models.frontend.frontends_utils.dnn_beamformer import DNN_Beamformer |
| | | from funasr.models.frontend.frontends_utils.dnn_wpe import DNN_WPE |
| | | |
| | | |
class Frontend(nn.Module):
    """Multi-channel speech enhancement front-end.

    Optionally applies DNN-WPE dereverberation followed by DNN mask-based
    MVDR beamforming to a complex STFT input; 3-dim (single-channel) input
    passes through untouched.
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate=0.0,
    ):
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Enhance the input spectrum.

        Args:
            x: Complex STFT, (B, T, F) single-channel or (B, T, C, F)
                multi-channel.
            ilens: Valid lengths per utterance, (B,).
        Returns:
            Enhanced spectrum, lengths, and the last estimated mask
            (None when no enhancement module was applied).
        Raises:
            ValueError: If ``x`` is not 3- or 4-dimensional.
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # During training, randomly apply at most one of WPE /
                # beamformer per batch; (False, False) keeps the raw input
                # unless the multi-speaker case forces enhancement.
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))

                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
| | | |
| | | |
def frontend_for(args, idim):
    """Build a :class:`Frontend` from an argparse-style namespace.

    Args:
        args: Namespace carrying the WPE and beamformer options.
        idim: Input feature (frequency-bin) dimension.
    """
    wpe_opts = dict(
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
    )
    beamformer_opts = dict(
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
    return Frontend(idim=idim, **wpe_opts, **beamformer_opts)
| New file |
| | |
| | | from typing import Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from torch.nn import functional as F |
| | | from torch_complex.tensor import ComplexTensor |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.rnn.encoders import RNN |
| | | from funasr.modules.rnn.encoders import RNNP |
| | | |
| | | |
class MaskEstimator(torch.nn.Module):
    """RNN-based time-frequency mask estimator.

    Args:
        type: RNN spec string, e.g. "blstmp"; a trailing "p" selects the
            projected variant (RNNP) and a "vgg" prefix is stripped.
        idim: Input feature (frequency) dimension.
        layers: Number of RNN layers.
        units: RNN hidden units.
        projs: RNN projection size (output dim fed to the mask heads).
        dropout: Dropout rate.
        nmask: Number of masks to produce.
    """

    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
        super().__init__()
        # No subsampling between RNN layers.
        subsample = np.ones(layers + 1, dtype=np.int32)

        # NOTE: lstrip/rstrip strip character SETS, not literal prefixes /
        # suffixes; this matches the historical behaviour for the supported
        # type strings (e.g. "vggblstmp" -> "blstm").
        typ = type.lstrip("vgg").rstrip("p")
        if type[-1] == "p":
            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
        else:
            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)

        self.type = type
        self.nmask = nmask
        # One sigmoid head per mask.
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
        )

    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            masks: A tuple of ``nmask`` masks, each (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2) ** 0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F) — each channel is a sequence.
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            mask = torch.sigmoid(mask)
            # Zero padding
            # BUG FIX: `masked_fill` is out-of-place; the original discarded
            # its result, so padded frames kept non-zero mask values. Keep
            # the returned tensor.
            mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Take cares of multi gpu cases: If input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens