funasr/models/ct_transformer/model.py
@@ -46,7 +46,7 @@ self.embed = nn.Embedding(vocab_size, embed_unit) encoder_class = tables.encoder_classes.get(encoder.lower()) encoder_class = tables.encoder_classes.get(encoder) encoder = encoder_class(**encoder_conf) self.decoder = nn.Linear(att_unit, punc_size) @@ -60,7 +60,7 @@ def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs): """Compute loss value from buffer sequences. Args: