From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/models/e2e_diar_eend_ola.py |  187 ++++++++++++++++++++--------------------------
 1 files changed, 83 insertions(+), 104 deletions(-)

diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index f589269..a0b545a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,22 +1,20 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
 
 import numpy as np
 import torch
 import torch.nn as  nn
-from typeguard import check_argument_types
+import torch.nn.functional as F
 
+from funasr.models.base_model import FunASRModel
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -34,12 +32,35 @@
     return att
 
 
-class DiarEENDOLAModel(AbsESPnetModel):
+def pad_labels(ts, out_size):
+    for i, t in enumerate(ts):
+        if t.shape[1] < out_size:
+            ts[i] = F.pad(
+                t,
+                (0, out_size - t.shape[1], 0, 0),
+                mode='constant',
+                value=0.
+            )
+    return ts
+
+
+def pad_results(ys, out_size):
+    ys_padded = []
+    for i, y in enumerate(ys):
+        if y.shape[1] < out_size:
+            ys_padded.append(
+                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+        else:
+            ys_padded.append(y)
+    return ys_padded
+
+
+class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            frontend: WavFrontendMel23,
+            frontend: Optional[WavFrontendMel23],
             encoder: EENDOLATransformerEncoder,
             encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
@@ -48,11 +69,9 @@
             mapping_dict=None,
             **kwargs,
     ):
-        assert check_argument_types()
-
         super().__init__()
         self.frontend = frontend
-        self.encoder = encoder
+        self.enc = encoder
         self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
@@ -60,7 +79,7 @@
             mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
             self.mapping_dict = mapping_dict
         # PostNet
-        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
         self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
 
     def forward_encoder(self, xs, ilens):
@@ -68,7 +87,7 @@
         pad_shape = xs.shape
         xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
         xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
-        emb = self.encoder(xs, xs_mask)
+        emb = self.enc(xs, xs_mask)
         emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
         emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
         return emb
@@ -76,8 +95,9 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
-        outputs, (_, _) = self.PostNet(logits)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+                                                   enforce_sorted=False)
+        outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
         outputs = [self.output_layer(output) for output in outputs]
@@ -85,96 +105,45 @@
 
     def forward(
             self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+            speech: List[torch.Tensor],
+            speaker_labels: List[torch.Tensor],
+            orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Frontend + Encoder + Decoder + Calc loss
 
-        Args:
-            speech: (Batch, Length, ...)
-            speech_lengths: (Batch, )
-            text: (Batch, Length)
-            text_lengths: (Batch,)
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
-        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
+        assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
+        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
+        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
+        batch_size = len(speech)
 
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
+        # Encoder
+        encoder_out = self.forward_encoder(speech, speech_lengths)
 
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
+        # Encoder-decoder attractor
+        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+                                                                    speaker_labels_lengths)
+        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
 
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
+        # pit loss
+        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+        # pse loss
+        with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+                            to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
         stats = dict()
-
-        # 1. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc
-
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-            loss_interctc = loss_interctc / len(intermediate_outs)
-
-            # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-
-        # 2b. Attention decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-
-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
+        stats["pse_loss"] = pse_loss.detach()
+        stats["pit_loss"] = pit_loss.detach()
+        stats["attractor_loss"] = attractor_loss.detach()
+        stats["batch_size"] = batch_size
 
         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
@@ -185,14 +154,11 @@
 
     def estimate_sequential(self,
                             speech: torch.Tensor,
-                            speech_lengths: torch.Tensor,
                             n_speakers: int = None,
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
-        if self.frontend is not None:
-            speech = self.frontend(speech)
-        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
@@ -233,10 +199,23 @@
                 pred[i] = pred[i - 1]
             else:
                 pred[i] = 0
-        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
+        pred = [self.inv_mapping_func(i) for i in pred]
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
         decisions = torch.from_numpy(
             np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
             torch.float32)
         decisions = decisions[:, :n_speaker]
         return decisions
+
+    def inv_mapping_func(self, label):
+
+        if not isinstance(label, int):
+            label = int(label)
+        if label in self.mapping_dict['label2dec'].keys():
+            num = self.mapping_dict['label2dec'][label]
+        else:
+            num = -1
+        return num
+
+    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+        pass
\ No newline at end of file

--
Gitblit v1.9.1