From ff0310bfb4ed69f00cbeab89a58f958ae5091d70 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 06 七月 2023 16:24:35 +0800
Subject: [PATCH] update eend_ola

---
 funasr/datasets/small_datasets/sequence_iter_factory.py |    4 
 funasr/build_utils/build_args.py                        |    6 +
 funasr/modules/eend_ola/encoder.py                      |   20 --
 funasr/build_utils/build_dataloader.py                  |   17 ++
 funasr/modules/eend_ola/utils/losses.py                 |   77 ++++--------
 funasr/build_utils/build_diar_model.py                  |    6 
 funasr/modules/eend_ola/eend_ola_dataloader.py          |   57 +++++++++
 funasr/models/e2e_diar_eend_ola.py                      |  167 ++++++++++++---------------
 8 files changed, 184 insertions(+), 170 deletions(-)

diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 632c134..31f210e 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -86,6 +86,12 @@
         from funasr.build_utils.build_diar_model import class_choices_list
         for class_choices in class_choices_list:
             class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
 
     elif args.task_name == "sv":
         from funasr.build_utils.build_sv_model import class_choices_list
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
 
 def build_dataloader(args):
     if args.dataset_type == "small":
-        train_iter_factory = SequenceIterFactory(args, mode="train")
-        valid_iter_factory = SequenceIterFactory(args, mode="valid")
+        if args.task_name == "diar" and args.model == "eend_ola":
+            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+            train_iter_factory = EENDOLADataLoader(
+                data_file=args.train_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=args.dataset_conf["num_workers"],
+                shuffle=True)
+            valid_iter_factory = EENDOLADataLoader(
+                data_file=args.valid_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=0,
+                shuffle=False)
+        else:
+            train_iter_factory = SequenceIterFactory(args, mode="train")
+            valid_iter_factory = SequenceIterFactory(args, mode="valid")
     elif args.dataset_type == "large":
         train_iter_factory = LargeDataLoader(args, mode="train")
         valid_iter_factory = LargeDataLoader(args, mode="valid")
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
index 0ea3127..444636a 100644
--- a/funasr/build_utils/build_diar_model.py
+++ b/funasr/build_utils/build_diar_model.py
@@ -198,16 +198,14 @@
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
         else:
             frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
     else:
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
-        input_size = args.input_size
 
     # encoder
     encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+    encoder = encoder_class(**args.encoder_conf)
 
     if args.model == "sond":
         # data augmentation for spectrogram
@@ -272,7 +270,7 @@
             **args.model_conf,
         )
 
-    elif args.model_name == "eend_ola":
+    elif args.model == "eend_ola":
         # encoder-decoder attractor
         encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
         encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
index 3ebcc5a..e748c3d 100644
--- a/funasr/datasets/small_datasets/sequence_iter_factory.py
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -57,7 +57,7 @@
             data_path_and_name_and_type,
             preprocess=preprocess_fn,
             dest_sample_rate=dest_sample_rate,
-            speed_perturb=args.speed_perturb if mode=="train" else None,
+            speed_perturb=args.speed_perturb if mode == "train" else None,
         )
 
         # sampler
@@ -84,7 +84,7 @@
             args.max_update = len(bs_list) * args.max_epoch
             logging.info("Max update: {}".format(args.max_update))
 
-        if args.distributed and mode=="train":
+        if args.distributed and mode == "train":
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
             for batch in batches:
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index ae3a436..af0fd62 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,21 +1,21 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
 
 import numpy as np
 import torch
 import torch.nn as  nn
+import torch.nn.functional as F
+from typeguard import check_argument_types
 
+from funasr.models.base_model import FunASRModel
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -33,12 +33,35 @@
     return att
 
 
+def pad_labels(ts, out_size):
+    for i, t in enumerate(ts):
+        if t.shape[1] < out_size:
+            ts[i] = F.pad(
+                t,
+                (0, out_size - t.shape[1], 0, 0),
+                mode='constant',
+                value=0.
+            )
+    return ts
+
+
+def pad_results(ys, out_size):
+    ys_padded = []
+    for i, y in enumerate(ys):
+        if y.shape[1] < out_size:
+            ys_padded.append(
+                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+        else:
+            ys_padded.append(y)
+    return ys_padded
+
+
 class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            frontend: WavFrontendMel23,
+            frontend: Optional[WavFrontendMel23],
             encoder: EENDOLATransformerEncoder,
             encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
@@ -47,11 +70,12 @@
             mapping_dict=None,
             **kwargs,
     ):
+        assert check_argument_types()
 
         super().__init__()
         self.frontend = frontend
         self.enc = encoder
-        self.eda = encoder_decoder_attractor
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
@@ -74,7 +98,8 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+                                                   enforce_sorted=False)
         outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -83,95 +108,51 @@
 
     def forward(
             self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+            speech: List[torch.Tensor],
+            speech_lengths: torch.Tensor,  # num_frames of each sample
+            speaker_labels: List[torch.Tensor],
+            speaker_labels_lengths: torch.Tensor,  # num_speakers of each sample
+            orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Frontend + Encoder + Decoder + Calc loss
-        Args:
-            speech: (Batch, Length, ...)
-            speech_lengths: (Batch, )
-            text: (Batch, Length)
-            text_lengths: (Batch,)
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
+
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
+                len(speech)
+                == len(speech_lengths)
+                == len(speaker_labels)
+                == len(speaker_labels_lengths)
+        ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths))
+        batch_size = len(speech)
 
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
+        # Encoder
+        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        encoder_out = self.forward_encoder(speech, speech_lengths)
 
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
+        # Encoder-decoder attractor
+        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+                                                                    speaker_labels_lengths)
+        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
 
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
+        # pit loss
+        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+        # pse loss
+        with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+                            to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
         stats = dict()
-
-        # 1. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc
-
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-            loss_interctc = loss_interctc / len(intermediate_outs)
-
-            # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-
-        # 2b. Attention decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-
-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
+        stats["pse_loss"] = pse_loss.detach()
+        stats["pit_loss"] = pit_loss.detach()
+        stats["attractor_loss"] = attractor_loss.detach()
+        stats["batch_size"] = batch_size
 
         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
@@ -193,10 +174,10 @@
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.eda.estimate(
+            attractors, probs = self.encoder_decoder_attractor.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:
diff --git a/funasr/modules/eend_ola/eend_ola_dataloader.py b/funasr/modules/eend_ola/eend_ola_dataloader.py
new file mode 100644
index 0000000..2ee9272
--- /dev/null
+++ b/funasr/modules/eend_ola/eend_ola_dataloader.py
@@ -0,0 +1,57 @@
+import logging
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+
+def custom_collate(batch):
+    keys, speech, speaker_labels, orders = zip(*batch)
+    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
+    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
+    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
+    batch = dict(speech=speech,
+                 speaker_labels=speaker_labels,
+                 orders=orders)
+
+    return keys, batch
+
+
+class EENDOLADataset(Dataset):
+    def __init__(
+            self,
+            data_file,
+    ):
+        self.data_file = data_file
+        with open(data_file) as f:
+            lines = f.readlines()
+        self.samples = [line.strip().split() for line in lines]
+        logging.info("total samples: {}".format(len(self.samples)))
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        key, speech_path, speaker_label_path = self.samples[idx]
+        speech = kaldiio.load_mat(speech_path)
+        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
+
+        order = np.arange(speech.shape[0])
+        np.random.shuffle(order)
+
+        return key, speech, speaker_label, order
+
+
+class EENDOLADataLoader():
+    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
+        dataset = EENDOLADataset(data_file)
+        self.data_loader = DataLoader(dataset,
+                                      batch_size=batch_size,
+                                      collate_fn=custom_collate,
+                                      shuffle=shuffle,
+                                      num_workers=num_workers)
+
+    def build_iter(self, epoch):
+        return self.data_loader
\ No newline at end of file
diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py
index 90a63f3..3065884 100644
--- a/funasr/modules/eend_ola/encoder.py
+++ b/funasr/modules/eend_ola/encoder.py
@@ -91,6 +91,7 @@
                  dropout_rate: float = 0.1,
                  use_pos_emb: bool = False):
         super(EENDOLATransformerEncoder, self).__init__()
+        self.linear_in = nn.Linear(idim, n_units)
         self.lnorm_in = nn.LayerNorm(n_units)
         self.n_layers = n_layers
         self.dropout = nn.Dropout(dropout_rate)
@@ -104,25 +105,10 @@
             setattr(self, '{}{:d}'.format("ff_", i),
                     PositionwiseFeedForward(n_units, e_units, dropout_rate))
         self.lnorm_out = nn.LayerNorm(n_units)
-        if use_pos_emb:
-            self.pos_enc = torch.nn.Sequential(
-                torch.nn.Linear(idim, n_units),
-                torch.nn.LayerNorm(n_units),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                PositionalEncoding(n_units, dropout_rate),
-            )
-        else:
-            self.linear_in = nn.Linear(idim, n_units)
-            self.pos_enc = None
 
     def __call__(self, x, x_mask=None):
         BT_size = x.shape[0] * x.shape[1]
-        if self.pos_enc is not None:
-            e = self.pos_enc(x)
-            e = e.view(BT_size, -1)
-        else:
-            e = self.linear_in(x.reshape(BT_size, -1))
+        e = self.linear_in(x.reshape(BT_size, -1))
         for i in range(self.n_layers):
             e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
             s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@@ -130,4 +116,4 @@
             e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
             s = getattr(self, '{}{:d}'.format("ff_", i))(e)
             e = e + self.dropout(s)
-        return self.lnorm_out(e)
+        return self.lnorm_out(e)
\ No newline at end of file
diff --git a/funasr/modules/eend_ola/utils/losses.py b/funasr/modules/eend_ola/utils/losses.py
index af0181d..756952d 100644
--- a/funasr/modules/eend_ola/utils/losses.py
+++ b/funasr/modules/eend_ola/utils/losses.py
@@ -1,11 +1,10 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from itertools import permutations
-from torch import nn
+from scipy.optimize import linear_sum_assignment
 
 
-def standard_loss(ys, ts, label_delay=0):
+def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
     n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
@@ -13,55 +12,29 @@
     return loss
 
 
-def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
-    max_n_speakers = ts[0].shape[1]
-    olens = [y.shape[0] for y in ys]
-    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
-    ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
-    ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
+def fast_batch_pit_n_speaker_loss(ys, ts):
+    with torch.no_grad():
+        bs = len(ys)
+        indices = []
+        for b in range(bs):
+            y = ys[b].transpose(0, 1)
+            t = ts[b].transpose(0, 1)
+            C, _ = t.shape
+            y = y[:, None, :].repeat(1, C, 1)
+            t = t[None, :, :].repeat(C, 1, 1)
+            bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1)
+            C = bce_loss.cpu()
+            indices.append(linear_sum_assignment(C))
+    labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]
 
-    losses = []
-    for shift in range(max_n_speakers):
-        ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
-        ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
-        loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
-        if ys_mask is not None:
-            loss = loss * ys_mask
-        loss = torch.sum(loss, dim=1)
-        losses.append(loss)
-    losses = torch.stack(losses, dim=2)
+    return labels_perm
 
-    perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
-    perms = torch.from_numpy(perms).to(losses.device)
-    y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
-    t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
 
-    losses_perm = []
-    for t_ind in t_inds:
-        losses_perm.append(
-            torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
-    losses_perm = torch.stack(losses_perm, dim=1)
-
-    def select_perm_indices(num, max_num):
-        perms = list(permutations(range(max_num)))
-        sub_perms = list(permutations(range(num)))
-        return [
-            [x[:num] for x in perms].index(perm)
-            for perm in sub_perms]
-
-    masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
-    for i, t in enumerate(ts):
-        n_speakers = n_speakers_list[i]
-        indices = select_perm_indices(n_speakers, max_n_speakers)
-        masks[i, indices] = 0
-    losses_perm += masks
-
-    min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
-    min_loss = min_loss / n_frames
-
-    min_indices = torch.argmin(losses_perm, dim=1)
-    labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
-    labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
-
-    return min_loss, labels_perm
+def cal_power_loss(logits, power_ts):
+    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
+              zip(logits, power_ts)]
+    loss = torch.sum(torch.stack(losses))
+    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
+        power_ts[0].device)
+    loss = loss / n_frames
+    return loss

--
Gitblit v1.9.1