游雁
2023-12-11 d77910eb6d171727f2350e45c31c91436c4c8891
funasr/cli/models/paraformer.py
@@ -193,6 +193,7 @@
         self.decoder.embed = None
      self.use_1st_decoder_loss = use_1st_decoder_loss
      self.length_normalized_loss = length_normalized_loss
   
   def forward(
      self,
@@ -302,6 +303,8 @@
      stats["loss"] = torch.clone(loss.detach())
      
      # force_gatherable: to-device and to-tensor if scalar for DataParallel
      if self.length_normalized_loss:
         batch_size = (text_lengths + self.predictor_bias).sum()
      loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
      return loss, stats, weight