游雁
2023-12-13 806a03609df033d61f824f1ab8527eb88fe837ad
funasr/cli/models/paraformer.py
@@ -193,6 +193,7 @@
         self.decoder.embed = None
      self.use_1st_decoder_loss = use_1st_decoder_loss
      self.length_normalized_loss = length_normalized_loss
   
   def forward(
      self,
@@ -302,6 +303,8 @@
      stats["loss"] = torch.clone(loss.detach())
      
      # force_gatherable: to-device and to-tensor if scalar for DataParallel
      if self.length_normalized_loss:
         batch_size = (text_lengths + self.predictor_bias).sum()
      loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
      return loss, stats, weight
   
@@ -594,7 +597,7 @@
         for li in range(bsz):
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
-               input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+               input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
         input_mask = input_mask.eq(1)
         input_mask = input_mask.masked_fill(~nonpad_positions, False)
         input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
@@ -624,7 +627,7 @@
      for li in range(bsz):
         target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
         if target_num > 0:
-            input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num], value=0)
+            input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
      input_mask = input_mask.eq(1)
      input_mask = input_mask.masked_fill(~nonpad_positions, False)
      input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)