游雁
2024-01-15 a035d68e860ea6decdf422c0fc04eda4fc4de397
funasr/models/ct_transformer/model.py
@@ -46,7 +46,7 @@
        
        
        self.embed = nn.Embedding(vocab_size, embed_unit)
        encoder_class = tables.encoder_classes.get(encoder.lower())
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.decoder = nn.Linear(att_unit, punc_size)
@@ -60,7 +60,7 @@
        
        
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
        """Compute loss value from buffer sequences.
        Args: