aky15
2023-05-24 8c24f52fd25fce2df46dc4d9dffb45619dc38f9f
Dev aky2 (#549)

* support resume model from pai

* add padding for streaming rnnt conv input

* fix large dataset training bug

* bug fix

---------

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
3个文件已修改
14 ■■■■■ 已修改文件
funasr/datasets/large_datasets/build_dataloader.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/tokenize.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/subsampling.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/build_dataloader.py
@@ -73,8 +73,8 @@
            seg_dict = load_seg_dict(args.seg_dict_file)
        if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
            punc_dict = read_symbol_table(args.punc_dict_file)
        if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
        if hasattr(args, "bpemodel") and args.bpemodel is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
        self.dataset_conf = args.dataset_conf
        self.frontend_conf = args.frontend_conf
        logging.info("dataloader config: {}".format(self.dataset_conf))
funasr/datasets/large_datasets/utils/tokenize.py
@@ -46,10 +46,8 @@
    text = data["text"]
    token = []
    vad = -2
    if bpe_tokenizer is not None:
        text = bpe_tokenizer.text2tokens(text)
        text = bpe_tokenizer.text2tokens(" ".join(text))
    if seg_dict is not None:
        assert isinstance(seg_dict, dict)
        text = seg_tokenize(text, seg_dict)
funasr/modules/subsampling.py
@@ -506,9 +506,9 @@
                )
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
                    torch.nn.ReLU(),
                )
@@ -597,7 +597,7 @@
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
            return mask[:, ::2][:, ::self.stride_2]
        else:
            return mask