Large-dataset loader: thread `frontend_conf` through ArkDataLoader/AudioDataset/Dataset and resample loaded audio to the frontend sampling rate (`frontend_conf["fs"]`) when it differs from the file's native rate
| | |
| | | return seg_dict |
| | | |
| | | class ArkDataLoader(AbsIterFactory): |
| | | def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"): |
| | | def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None, mode="train"): |
| | | symbol_table = read_symbol_table(dict_file) if dict_file is not None else None |
| | | if seg_dict_file is not None: |
| | | seg_dict = load_seg_dict(seg_dict_file) |
| | |
| | | else: |
| | | punc_dict = None |
| | | self.dataset_conf = dataset_conf |
| | | self.frontend_conf = frontend_conf |
| | | logging.info("dataloader config: {}".format(self.dataset_conf)) |
| | | batch_mode = self.dataset_conf.get("batch_mode", "padding") |
| | | self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, |
| | | self.dataset_conf, mode=mode, batch_mode=batch_mode) |
| | | self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode) |
| | | |
| | | def build_iter(self, epoch, shuffle=True): |
| | | self.dataset.set_epoch(epoch) |
| | |
| | | |
| | | |
| | | class AudioDataset(IterableDataset): |
| | | def __init__(self, scp_lists, data_names, data_types, shuffle=True, mode="train"): |
| | | def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"): |
| | | self.scp_lists = scp_lists |
| | | self.data_names = data_names |
| | | self.data_types = data_types |
| | | self.frontend_conf = frontend_conf |
| | | self.shuffle = shuffle |
| | | self.mode = mode |
| | | self.epoch = -1 |
| | |
| | | elif data_type == "sound": |
| | | key, path = item.strip().split() |
| | | waveform, sampling_rate = torchaudio.load(path) |
| | | if self.frontend_conf is not None: |
| | | if sampling_rate != self.frontend_conf["fs"]: |
| | | waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, |
| | | new_freq=self.frontend_conf["fs"])(waveform) |
| | | sampling_rate = self.frontend_conf["fs"] |
| | | waveform = waveform.numpy() |
| | | mat = waveform[0] |
| | | sample_dict[data_name] = mat |
| | |
| | | seg_dict, |
| | | punc_dict, |
| | | conf, |
| | | frontend_conf, |
| | | mode="train", |
| | | batch_mode="padding"): |
| | | scp_lists = read_lists(data_list_file) |
| | | shuffle = conf.get('shuffle', True) |
| | | data_names = conf.get("data_names", "speech,text") |
| | | data_types = conf.get("data_types", "kaldi_ark,text") |
| | | dataset = AudioDataset(scp_lists, data_names, data_types, shuffle=shuffle, mode=mode) |
| | | dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |
| | |
| | | if args.dataset_type == "large": |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |