funasr/models/sond/encoder/ci_scorers.py
@@ -7,9 +7,9 @@ super().__init__() def forward( self, xs_pad: torch.Tensor, spk_emb: torch.Tensor, self, xs_pad: torch.Tensor, spk_emb: torch.Tensor, ): # xs_pad: B, T, D # spk_emb: B, N, D @@ -22,9 +22,9 @@ super().__init__() def forward( self, xs_pad: torch.Tensor, spk_emb: torch.Tensor, self, xs_pad: torch.Tensor, spk_emb: torch.Tensor, ): # xs_pad: B, T, D # spk_emb: B, N, D