zhifu gao
2023-12-07 5cf512419c282f833ee35a2f31890bff00d94343
funasr/models/e2e_asr.py
@@ -122,6 +122,7 @@
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.length_normalized_loss = length_normalized_loss
    def forward(
            self,
@@ -220,6 +221,8 @@
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + 1).sum().type_as(batch_size)
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight