游雁
2024-02-20 d79287c37e4e7ae2694a992cbbfb03a5ca4f7670
funasr/models/contextual_paraformer/model.py
@@ -65,11 +65,9 @@
        if bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        elif bias_encoder_type == 'mean':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
        else:
            logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))