funasr/models/predictor/cif.py
@@ -31,10 +31,12 @@ alphas = torch.sigmoid(output) alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) if mask is not None: alphas = alphas * mask.transpose(-1, -2).float() mask = mask.transpose(-1, -2).float() alphas = alphas * mask if mask_chunk_predictor is not None: alphas = alphas * mask_chunk_predictor alphas = alphas.squeeze(-1) mask = mask.squeeze(-1) if target_label_length is not None: target_length = target_label_length elif target_label is not None: