From 54931dd4e1a099d7d6f144c4e12e5453deb3aa26 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 28 六月 2023 10:41:57 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/models/e2e_diar_eend_ola.py |  369 ++++++++++++++++------------------------------------
 1 files changed, 113 insertions(+), 256 deletions(-)

diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 967c0d4..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,38 +1,25 @@
 # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-import logging
-import torch
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.e2e_asr_common import ErrorCalculator
-from funasr.modules.eend_ola.encoder import TransformerEncoder
+from typing import Dict
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as  nn
+from typeguard import check_argument_types
+
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
-from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from typeguard import check_argument_types
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-    from torch.cuda.amp import autocast
+    pass
 else:
     # Nothing to do if torch<1.6.0
     @contextmanager
@@ -40,13 +27,22 @@
         yield
 
 
-class DiarEENDOLAModel(AbsESPnetModel):
-    """CTC-attention hybrid Encoder-Decoder model"""
+def pad_attractor(att, max_n_speakers):
+    C, D = att.shape
+    if C < max_n_speakers:
+        att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
+    return att
+
+
+class DiarEENDOLAModel(FunASRModel):
+    """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            encoder: TransformerEncoder,
-            eda: EncoderDecoderAttractor,
+            frontend: WavFrontendMel23,
+            encoder: EENDOLATransformerEncoder,
+            encoder_decoder_attractor: EncoderDecoderAttractor,
+            n_units: int = 256,
             max_n_speaker: int = 8,
             attractor_loss_weight: float = 1.0,
             mapping_dict=None,
@@ -55,13 +51,37 @@
         assert check_argument_types()
 
         super().__init__()
-        self.encoder = encoder
-        self.eda = eda
+        self.frontend = frontend
+        self.enc = encoder
+        self.eda = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
             mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
             self.mapping_dict = mapping_dict
+        # PostNet
+        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
+
+    def forward_encoder(self, xs, ilens):
+        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
+        pad_shape = xs.shape
+        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
+        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
+        emb = self.enc(xs, xs_mask)
+        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
+        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
+        return emb
+
+    def forward_post_net(self, logits, ilens):
+        maxlen = torch.max(ilens).to(torch.int).item()
+        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        outputs, (_, _) = self.postnet(logits)
+        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
+        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
+        outputs = [self.output_layer(output) for output in outputs]
+        return outputs
 
     def forward(
             self,
@@ -71,7 +91,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -92,7 +111,7 @@
         text = text[:, : text_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -163,233 +182,71 @@
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
-    def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-    ) -> Dict[str, torch.Tensor]:
-        if self.extract_feats_in_collect_stats:
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+    def estimate_sequential(self,
+                            speech: torch.Tensor,
+                            speech_lengths: torch.Tensor,
+                            n_speakers: int = None,
+                            shuffle: bool = True,
+                            threshold: float = 0.5,
+                            **kwargs):
+        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        emb = self.forward_encoder(speech, speech_lengths)
+        if shuffle:
+            orders = [np.arange(e.shape[0]) for e in emb]
+            for order in orders:
+                np.random.shuffle(order)
+            attractors, probs = self.eda.estimate(
+                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            # Generate dummy stats if extract_feats_in_collect_stats is False
-            logging.warning(
-                "Generating dummy stats for feats and feats_lengths, "
-                "because encoder_conf.extract_feats_in_collect_stats is "
-                f"{self.extract_feats_in_collect_stats}"
-            )
-            feats, feats_lengths = speech, speech_lengths
-        return {"feats": feats, "feats_lengths": feats_lengths}
+            attractors, probs = self.eda.estimate(emb)
+        attractors_active = []
+        for p, att, e in zip(probs, attractors, emb):
+            if n_speakers and n_speakers >= 0:
+                att = att[:n_speakers, ]
+                attractors_active.append(att)
+            elif threshold is not None:
+                silence = torch.nonzero(p < threshold)[0]
+                n_spk = silence[0] if silence.size else None
+                att = att[:n_spk, ]
+                attractors_active.append(att)
+            else:
+                NotImplementedError('n_speakers or threshold has to be given.')
+        raw_n_speakers = [att.shape[0] for att in attractors_active]
+        attractors = [
+            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
+            for att in attractors_active]
+        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
+        logits = self.forward_post_net(ys, speech_lengths)
+        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
+              zip(logits, raw_n_speakers)]
 
-    def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        return ys, emb, attractors, raw_n_speakers
 
-        Args:
-            speech: (Batch, Length, ...)
-            speech_lengths: (Batch, )
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+    def recover_y_from_powerlabel(self, logit, n_speaker):
+        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
+        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
+        for i in oov_index:
+            if i > 0:
+                pred[i] = pred[i - 1]
+            else:
+                pred[i] = 0
+        pred = [self.inv_mapping_func(i) for i in pred]
+        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
+        decisions = torch.from_numpy(
+            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
+            torch.float32)
+        decisions = decisions[:, :n_speaker]
+        return decisions
 
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
+    def inv_mapping_func(self, label):
 
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
-        # 4. Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                feats, feats_lengths, ctc=self.ctc
-            )
+        if not isinstance(label, int):
+            label = int(label)
+        if label in self.mapping_dict['label2dec'].keys():
+            num = self.mapping_dict['label2dec'][label]
         else:
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
+            num = -1
+        return num
 
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
-
-        return encoder_out, encoder_out_lens
-
-    def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-
-        if self.frontend is not None:
-            # Frontend
-            #  e.g. STFT and Feature extract
-            #       data_loader may send time-domain signal in this case
-            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            # No frontend and no feature extract
-            feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
-
-    def nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute negative log likelihood(nll) from transformer-decoder
-
-        Normally, this function is called in batchify_nll.
-
-        Args:
-            encoder_out: (Batch, Length, Dim)
-            encoder_out_lens: (Batch,)
-            ys_pad: (Batch, Length)
-            ys_pad_lens: (Batch,)
-        """
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )  # [batch, seqlen, dim]
-        batch_size = decoder_out.size(0)
-        decoder_num_class = decoder_out.size(2)
-        # nll: negative log-likelihood
-        nll = torch.nn.functional.cross_entropy(
-            decoder_out.view(-1, decoder_num_class),
-            ys_out_pad.view(-1),
-            ignore_index=self.ignore_id,
-            reduction="none",
-        )
-        nll = nll.view(batch_size, -1)
-        nll = nll.sum(dim=1)
-        assert nll.size(0) == batch_size
-        return nll
-
-    def batchify_nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-            batch_size: int = 100,
-    ):
-        """Compute negative log likelihood(nll) from transformer-decoder
-
-        To avoid OOM, this fuction seperate the input into batches.
-        Then call nll for each batch and combine and return results.
-        Args:
-            encoder_out: (Batch, Length, Dim)
-            encoder_out_lens: (Batch,)
-            ys_pad: (Batch, Length)
-            ys_pad_lens: (Batch,)
-            batch_size: int, samples each batch contain when computing nll,
-                        you may change this to avoid OOM or increase
-                        GPU memory usage
-        """
-        total_num = encoder_out.size(0)
-        if total_num <= batch_size:
-            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-        else:
-            nll = []
-            start_idx = 0
-            while True:
-                end_idx = min(start_idx + batch_size, total_num)
-                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
-                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
-                batch_ys_pad = ys_pad[start_idx:end_idx, :]
-                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
-                batch_nll = self.nll(
-                    batch_encoder_out,
-                    batch_encoder_out_lens,
-                    batch_ys_pad,
-                    batch_ys_pad_lens,
-                )
-                nll.append(batch_nll)
-                start_idx = end_idx
-                if start_idx == total_num:
-                    break
-            nll = torch.cat(nll)
-        assert nll.size(0) == total_num
-        return nll
-
-    def _calc_att_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-    ):
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
-        return loss_att, acc_att, cer_att, wer_att
-
-    def _calc_ctc_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
+    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+        pass
\ No newline at end of file

--
Gitblit v1.9.1