aky15
2023-05-23 584d0bc0ebb0de360c8cc3c05c26e376b943fefa
Merge pull request #538 from alibaba-damo-academy/dev_aky2

rnnt support wav input
3个文件已修改
54 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 23 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -1510,8 +1510,13 @@
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        
-        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -1536,14 +1541,19 @@
        
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
-        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        
-        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        
        nbest_hyps = self.beam_search(enc_out[0])
        
funasr/tasks/abs_task.py
@@ -1376,25 +1376,10 @@
            # 7. Build iterator factories
            if args.dataset_type == "large":
-                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="eval")
+                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+                train_iter_factory = LargeDataLoader(args, mode="train")
+                valid_iter_factory = LargeDataLoader(args, mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(
                    args=args,
funasr/tasks/asr.py
@@ -363,12 +363,6 @@
            default=get_default_kwargs(CTC),
            help="The keyword arguments for CTC class.",
        )
-        group.add_argument(
-            "--joint_network_conf",
-            action=NestedDictAction,
-            default=None,
-            help="The keyword arguments for joint network class.",
-        )
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
@@ -1379,6 +1373,7 @@
    num_optimizers: int = 1
    class_choices_list = [
        model_choices,
        frontend_choices,
        specaug_choices,
        normalize_choices,
@@ -1476,7 +1471,7 @@
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
-            model_class = model_choices.get_class("asr")
+            model_class = model_choices.get_class("rnnt_unified")
        model = model_class(
            vocab_size=vocab_size,