Yabin Li
2023-08-08 b454a1054fadbff0ee963944ff42f66b98317582
funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
   def forward(self, hidden: torch.Tensor,
               mask: torch.Tensor,
               ):
      alphas, token_num = self.forward_cnn(hidden, mask)
      mask = mask.transpose(-1, -2).float()
      mask = mask.squeeze(-1)
      hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
      acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
      return acoustic_embeds, token_num, alphas, cif_peak
   def forward_cnn(self, hidden: torch.Tensor,
               mask: torch.Tensor,
               ):
      h = hidden
      context = h.transpose(1, 2)
      queries = self.pad(context)
@@ -49,12 +60,8 @@
      alphas = alphas * mask
      alphas = alphas.squeeze(-1)
      token_num = alphas.sum(-1)
      mask = mask.squeeze(-1)
      hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
      acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
      return acoustic_embeds, token_num, alphas, cif_peak
      return alphas, token_num
   
   def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
      b, t, d = hidden.size()
@@ -285,4 +292,4 @@
                                integrate)
    fires = torch.stack(list_fires, 1)
    return fires
    return fires