From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/eend/e2e_diar_eend_ola.py |  162 ++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 103 insertions(+), 59 deletions(-)

diff --git a/funasr/models/eend/e2e_diar_eend_ola.py b/funasr/models/eend/e2e_diar_eend_ola.py
index 28aa223..cae5d1f 100644
--- a/funasr/models/eend/e2e_diar_eend_ola.py
+++ b/funasr/models/eend/e2e_diar_eend_ola.py
@@ -4,13 +4,17 @@
 
 import numpy as np
 import torch
-import torch.nn as  nn
+import torch.nn as nn
 import torch.nn.functional as F
 
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.frontends.wav_frontend import WavFrontendMel23
 from funasr.models.eend.encoder import EENDOLATransformerEncoder
 from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
-from funasr.models.eend.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
+from funasr.models.eend.utils.losses import (
+    standard_loss,
+    cal_power_loss,
+    fast_batch_pit_n_speaker_loss,
+)
 from funasr.models.eend.utils.power import create_powerlabel
 from funasr.models.eend.utils.power import generate_mapping_dict
 from funasr.train_utils.device_funcs import force_gatherable
@@ -27,19 +31,16 @@
 def pad_attractor(att, max_n_speakers):
     C, D = att.shape
     if C < max_n_speakers:
-        att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
+        att = torch.cat(
+            [att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0
+        )
     return att
 
 
 def pad_labels(ts, out_size):
     for i, t in enumerate(ts):
         if t.shape[1] < out_size:
-            ts[i] = F.pad(
-                t,
-                (0, out_size - t.shape[1], 0, 0),
-                mode='constant',
-                value=0.
-            )
+            ts[i] = F.pad(t, (0, out_size - t.shape[1], 0, 0), mode="constant", value=0.0)
     return ts
 
 
@@ -48,7 +49,16 @@
     for i, y in enumerate(ys):
         if y.shape[1] < out_size:
             ys_padded.append(
-                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+                torch.cat(
+                    [
+                        y,
+                        torch.zeros(y.shape[0], out_size - y.shape[1])
+                        .to(torch.float32)
+                        .to(y.device),
+                    ],
+                    dim=1,
+                )
+            )
         else:
             ys_padded.append(y)
     return ys_padded
@@ -58,15 +68,15 @@
     """EEND-OLA diarization model"""
 
     def __init__(
-            self,
-            frontend: Optional[WavFrontendMel23],
-            encoder: EENDOLATransformerEncoder,
-            encoder_decoder_attractor: EncoderDecoderAttractor,
-            n_units: int = 256,
-            max_n_speaker: int = 8,
-            attractor_loss_weight: float = 1.0,
-            mapping_dict=None,
-            **kwargs,
+        self,
+        frontend: Optional[WavFrontendMel23],
+        encoder: EENDOLATransformerEncoder,
+        encoder_decoder_attractor: EncoderDecoderAttractor,
+        n_units: int = 256,
+        max_n_speaker: int = 8,
+        attractor_loss_weight: float = 1.0,
+        mapping_dict=None,
+        **kwargs,
     ):
         super().__init__()
         self.frontend = frontend
@@ -79,13 +89,15 @@
             self.mapping_dict = mapping_dict
         # PostNet
         self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
-        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
+        self.output_layer = nn.Linear(n_units, mapping_dict["oov"] + 1)
 
     def forward_encoder(self, xs, ilens):
         xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
         pad_shape = xs.shape
         xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
-        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
+        xs_mask = torch.nn.utils.rnn.pad_sequence(
+            xs_mask, batch_first=True, padding_value=0
+        ).unsqueeze(-2)
         emb = self.enc(xs, xs_mask)
         emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
         emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
@@ -94,34 +106,42 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
-                                                   enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(
+            logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False
+        )
         outputs, (_, _) = self.postnet(logits)
-        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
-        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
+        outputs = nn.utils.rnn.pad_packed_sequence(
+            outputs, batch_first=True, padding_value=-1, total_length=maxlen
+        )[0]
+        outputs = [output[: ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
         outputs = [self.output_layer(output) for output in outputs]
         return outputs
 
     def forward(
-            self,
-            speech: List[torch.Tensor],
-            speaker_labels: List[torch.Tensor],
-            orders: torch.Tensor,
+        self,
+        speech: List[torch.Tensor],
+        speaker_labels: List[torch.Tensor],
+        orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
 
         # Check that batch_size is unified
-        assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
+        assert len(speech) == len(speaker_labels), (len(speech), len(speaker_labels))
         speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
-        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
+        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(
+            torch.int64
+        )
         batch_size = len(speech)
 
         # Encoder
         encoder_out = self.forward_encoder(speech, speech_lengths)
 
         # Encoder-decoder attractor
-        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
-                                                                    speaker_labels_lengths)
-        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
+        attractor_loss, attractors = self.encoder_decoder_attractor(
+            [e[order] for e, order in zip(encoder_out, orders)], speaker_labels_lengths
+        )
+        speaker_logits = [
+            torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)
+        ]
 
         # pit loss
         pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
@@ -129,10 +149,17 @@
 
         # pse loss
         with torch.no_grad():
-            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
-                            to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+            power_ts = [
+                create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).to(
+                    encoder_out[0].device, non_blocking=True
+                )
+                for label in pit_speaker_labels
+            ]
         pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
-        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = [
+            torch.matmul(e, pad_att.permute(1, 0))
+            for e, pad_att in zip(encoder_out, pad_attractors)
+        ]
         pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
         pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
 
@@ -151,12 +178,14 @@
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
-    def estimate_sequential(self,
-                            speech: torch.Tensor,
-                            n_speakers: int = None,
-                            shuffle: bool = True,
-                            threshold: float = 0.5,
-                            **kwargs):
+    def estimate_sequential(
+        self,
+        speech: torch.Tensor,
+        n_speakers: int = None,
+        shuffle: bool = True,
+        threshold: float = 0.5,
+        **kwargs,
+    ):
         speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
@@ -164,35 +193,46 @@
             for order in orders:
                 np.random.shuffle(order)
             attractors, probs = self.encoder_decoder_attractor.estimate(
-                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
+                [
+                    e[torch.from_numpy(order).to(torch.long).to(speech[0].device)]
+                    for e, order in zip(emb, orders)
+                ]
+            )
         else:
             attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:
-                att = att[:n_speakers, ]
+                att = att[:n_speakers,]
                 attractors_active.append(att)
             elif threshold is not None:
                 silence = torch.nonzero(p < threshold)[0]
                 n_spk = silence[0] if silence.size else None
-                att = att[:n_spk, ]
+                att = att[:n_spk,]
                 attractors_active.append(att)
             else:
-                NotImplementedError('n_speakers or threshold has to be given.')
+                NotImplementedError("n_speakers or threshold has to be given.")
         raw_n_speakers = [att.shape[0] for att in attractors_active]
         attractors = [
-            pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
-            for att in attractors_active]
+            (
+                pad_attractor(att, self.max_n_speaker)
+                if att.shape[0] <= self.max_n_speaker
+                else att[: self.max_n_speaker]
+            )
+            for att in attractors_active
+        ]
         ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
         logits = self.forward_post_net(ys, speech_lengths)
-        ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
-              zip(logits, raw_n_speakers)]
+        ys = [
+            self.recover_y_from_powerlabel(logit, raw_n_speaker)
+            for logit, raw_n_speaker in zip(logits, raw_n_speakers)
+        ]
 
         return ys, emb, attractors, raw_n_speakers
 
     def recover_y_from_powerlabel(self, logit, n_speaker):
         pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
-        oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
+        oov_index = torch.where(pred == self.mapping_dict["oov"])[0]
         for i in oov_index:
             if i > 0:
                 pred[i] = pred[i - 1]
@@ -200,9 +240,13 @@
                 pred[i] = 0
         pred = [self.inv_mapping_func(i) for i in pred]
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
-        decisions = torch.from_numpy(
-            np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
-            torch.float32)
+        decisions = (
+            torch.from_numpy(
+                np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)
+            )
+            .to(logit.device)
+            .to(torch.float32)
+        )
         decisions = decisions[:, :n_speaker]
         return decisions
 
@@ -210,11 +254,11 @@
 
         if not isinstance(label, int):
             label = int(label)
-        if label in self.mapping_dict['label2dec'].keys():
-            num = self.mapping_dict['label2dec'][label]
+        if label in self.mapping_dict["label2dec"].keys():
+            num = self.mapping_dict["label2dec"][label]
         else:
             num = -1
         return num
 
     def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        pass
\ No newline at end of file
+        pass

--
Gitblit v1.9.1