From 73a7cf596bd1ebc3bd0c9674a21072f9eaf4cc57 Mon Sep 17 00:00:00 2001
From: dyyzhmm <dyyzhmm@163.com>
Date: 星期四, 16 三月 2023 15:24:56 +0800
Subject: [PATCH] Merge pull request #3 from alibaba-damo-academy/main

---
 funasr/models/e2e_diar_eend_ola.py |   35 +++++++++++++++++++++++------------
 1 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index f589269..097b23a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -52,15 +52,15 @@
 
         super().__init__()
         self.frontend = frontend
-        self.encoder = encoder
-        self.encoder_decoder_attractor = encoder_decoder_attractor
+        self.enc = encoder
+        self.eda = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
             mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
             self.mapping_dict = mapping_dict
         # PostNet
-        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
         self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
 
     def forward_encoder(self, xs, ilens):
@@ -68,7 +68,7 @@
         pad_shape = xs.shape
         xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
         xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
-        emb = self.encoder(xs, xs_mask)
+        emb = self.enc(xs, xs_mask)
         emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
         emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
         return emb
@@ -76,8 +76,8 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
-        outputs, (_, _) = self.PostNet(logits)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
         outputs = [self.output_layer(output) for output in outputs]
@@ -112,7 +112,7 @@
         text = text[:, : text_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -190,18 +190,16 @@
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
-        if self.frontend is not None:
-            speech = self.frontend(speech)
         speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.encoder_decoder_attractor.estimate(
+            attractors, probs = self.eda.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
+            attractors, probs = self.eda.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:
@@ -233,10 +231,23 @@
                 pred[i] = pred[i - 1]
             else:
                 pred[i] = 0
-        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
+        pred = [self.inv_mapping_func(i) for i in pred]
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
         decisions = torch.from_numpy(
             np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
             torch.float32)
         decisions = decisions[:, :n_speaker]
         return decisions
+
+    def inv_mapping_func(self, label):
+
+        if not isinstance(label, int):
+            label = int(label)
+        if label in self.mapping_dict['label2dec'].keys():
+            num = self.mapping_dict['label2dec'][label]
+        else:
+            num = -1
+        return num
+
+    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+        pass
\ No newline at end of file

--
Gitblit v1.9.1