funasr/cli/models/paraformer.py
@@ -193,6 +193,7 @@ self.decoder.embed = None self.use_1st_decoder_loss = use_1st_decoder_loss self.length_normalized_loss = length_normalized_loss def forward( self, @@ -302,6 +303,8 @@ stats["loss"] = torch.clone(loss.detach()) # force_gatherable: to-device and to-tensor if scalar for DataParallel if self.length_normalized_loss: batch_size = (text_lengths + self.predictor_bias).sum() loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight