funasr/models/predictor/cif.py
@@ -147,7 +147,7 @@ b, t, d = hidden.size() tail_threshold = self.tail_threshold tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device) tail_threshold = torch.reshape(tail_threshold, (1, 1)) tail_threshold = tail_threshold.unsqueeze(0).repeat(b, 1) alphas = torch.cat([alphas, tail_threshold], dim=1) zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device) hidden = torch.cat([hidden, zeros], dim=1)