funasr/models/e2e_asr.py
@@ -122,6 +122,7 @@ self.ctc = ctc self.extract_feats_in_collect_stats = extract_feats_in_collect_stats self.length_normalized_loss = length_normalized_loss def forward( self, @@ -220,6 +221,8 @@ stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel if self.length_normalized_loss: batch_size = (text_lengths + 1).sum().type_as(batch_size) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight