funasr/models/language_model/transformer_lm.py
@@ -66,9 +66,7 @@ y = self.decoder(h) return y, None def score( self, y: torch.Tensor, state: Any, x: torch.Tensor ) -> Tuple[torch.Tensor, Any]: def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]: """Score new token. Args: @@ -115,8 +113,7 @@ else: # transpose state of [batch, layer] into [layer, batch] batch_state = [ torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers) torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers) ] # batch decoding