语帆
2024-02-22 044199f80279825baba0831380c5fc0369abd298
test
1个文件已修改
11 ■■■■ 已修改文件
funasr/models/lcbnet/model.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/lcbnet/model.py
@@ -21,7 +21,7 @@
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
import pdb
@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
    """
@@ -90,7 +90,7 @@
        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
        bias_predictor_class = tables.encoder_classes.get_class(bias_predictor)
        bias_predictor = bias_predictor_class(args.bias_predictor_conf)
        bias_predictor = bias_predictor_class(bias_predictor_conf)
        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
@@ -117,9 +117,13 @@
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        # lcbnet
        self.text_encoder = text_encoder
        self.fusion_encoder = fusion_encoder
        self.bias_predictor = bias_predictor
        self.select_num = select_num
        self.select_length = select_length
        self.insert_blank = insert_blank
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
@@ -409,7 +413,8 @@
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        pdb.set_trace()
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths