From 1e4eba6a72ea97d9a9e733df3e3b1eb86e4fd44d Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 14 Mar 2023 23:22:31 +0800
Subject: [PATCH] EEND-OLA: add WavFrontendMel23 frontend, rename encoder/EDA modules, refactor forward passes

---
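Note: a minimal usage sketch of the refactored inference path (illustration
only, not part of the commit; `model`, `speech`, and `speech_lengths` are
assumed to come from a FunASR config and data loader):

    # speech: raw waveforms (B, num_samples); speech_lengths: (B,)
    ys, emb, attractors, raw_n_speakers = model.estimate_sequential(
        speech, speech_lengths, threshold=0.5)
    # ys: per-utterance frame-level speaker activity decisions
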
 funasr/models/e2e_diar_eend_ola.py |   95 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 70 insertions(+), 25 deletions(-)

diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 5c1c9ce..f589269 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -11,7 +11,8 @@
 import torch.nn as nn
 from typeguard import check_argument_types
 
-from funasr.modules.eend_ola.encoder import TransformerEncoder
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
@@ -26,13 +27,23 @@
         yield
 
 
+def pad_attractor(att, max_n_speakers):
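+    # Zero-pad the (C, D) attractor matrix along the speaker axis so that every
+    # sample exposes exactly max_n_speakers attractors downstream.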
+    C, D = att.shape
+    if C < max_n_speakers:
+        att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
+    return att
+
+
 class DiarEENDOLAModel(AbsESPnetModel):
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            encoder: TransformerEncoder,
-            eda: EncoderDecoderAttractor,
+            frontend: WavFrontendMel23,
+            encoder: EENDOLATransformerEncoder,
+            encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
             max_n_speaker: int = 8,
             attractor_loss_weight: float = 1.0,
@@ -42,8 +53,9 @@
         assert check_argument_types()
 
         super().__init__()
+        self.frontend = frontend
         self.encoder = encoder
-        self.eda = eda
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
@@ -52,6 +64,32 @@
         # PostNet
         self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
         self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
+
+    def forward_encoder(self, xs, ilens):
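+        # Pad the variable-length feature sequences into one batch, build a
+        # padding mask from ilens, encode, and split the output back into
+        # per-utterance embeddings of shape (T_i, D).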
+        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
+        pad_shape = xs.shape
+        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
+        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
+        emb = self.encoder(xs, xs_mask)
+        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
+        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
+        return emb
+
+    def forward_post_net(self, logits, ilens):
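+        # Pack the padded per-frame activity logits, run the single-layer LSTM
+        # PostNet, and project each frame onto the power-set label vocabulary
+        # (mapping_dict['oov'] + 1 classes).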
+        maxlen = torch.max(ilens).to(torch.int).item()
+        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu(), batch_first=True, enforce_sorted=False)
+        outputs, (_, _) = self.PostNet(logits)
+        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
+        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
+        outputs = [self.output_layer(output) for output in outputs]
+        return outputs
 
     def forward(
             self,
@@ -156,51 +194,57 @@
     def estimate_sequential(self,
                             speech: torch.Tensor,
                             speech_lengths: torch.Tensor,
-                            n_speakers: int,
-                            shuffle: bool,
-                            threshold: float,
+                            n_speakers: int = None,
+                            shuffle: bool = True,
+                            threshold: float = 0.5,
                             **kwargs):
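+        # When a frontend is attached, map raw waveforms to features first
+        # (WavFrontendMel23 presumably produces 23-dim mel filterbanks, as the
+        # class name suggests).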
+        if self.frontend is not None:
+            speech = self.frontend(speech)
         speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
-        emb = self.forward_core(speech)  # list, [(T1, C1), ..., (T1, C1)]
+        emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            # e[order]: embeddings after shuffling, list, [(T1, C1), ..., (T1, C1)]; each sample's T axis has been randomly permuted
-            # attractors, list, hts (the a_s in the paper), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)]
-            # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, )]
-            attractors, probs = self.eda.estimate(
-                [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)])
+            attractors, probs = self.encoder_decoder_attractor.estimate(
+                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
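+        # Keep one attractor per active speaker: the first n_speakers when the
+        # count is given, otherwise every attractor whose existence probability
+        # stays at or above the threshold.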
         for p, att, e in zip(probs, attractors, emb):
-            if n_speakers and n_speakers >= 0:  # select the corresponding number of ys for the specified speaker count
-                # TODO: when evaluating on datasets with varying speaker counts, consider deriving the speaker count per sample instead of fixing it up front
-                # raise NotImplementedError
+            if n_speakers and n_speakers >= 0:
                 att = att[:n_speakers, ]
                 attractors_active.append(att)
             elif threshold is not None:
-                silence = torch.nonzero(p < threshold)[0]  # the first index whose probability drops below the threshold marks the end; its value equals the speaker count
+                silence = np.where(p.data.cpu().numpy() < threshold)[0]
                 n_spk = silence[0] if silence.size else None
                 att = att[:n_spk, ]
                 attractors_active.append(att)
             else:
-                NotImplementedError('n_speakers or th has to be given.')
-        raw_n_speakers = [att.shape[0] for att in attractors_active]  # [C1, C2, ..., CB]
+                raise NotImplementedError('n_speakers or threshold has to be given.')
+        raw_n_speakers = [att.shape[0] for att in attractors_active]
         attractors = [
             pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
             for att in attractors_active]
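+        # Frame-wise speaker activity logits are the dot products between frame
+        # embeddings (T_i, D) and the padded attractors (max_n_speaker, D).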
         ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
-        # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)]
-        logits = self.cal_postnet(ys, self.max_n_speaker)
+        logits = self.forward_post_net(ys, speech_lengths)
         ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
               zip(logits, raw_n_speakers)]
 
         return ys, emb, attractors, raw_n_speakers
 
     def recover_y_from_powerlabel(self, logit, n_speaker):
-        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)  # (T, )
+        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
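+        # pred: (T,) indices into the power-set label vocabulary; OOV frames
+        # fall back to the previous frame's label in the loop below.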
         oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
         for i in oov_index:
             if i > 0:
@@ -208,7 +252,8 @@
             else:
                 pred[i] = 0
         pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
-        # print(pred)
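+        # Each mapped label is an integer whose binary digits mark the active
+        # speakers; the bit string is reversed so speaker 0 sits at position 0.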
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
         decisions = torch.from_numpy(
             np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(

--
Gitblit v1.9.1