| | |
| | | |
| | | |
| | | class AudioDataset(IterableDataset): |
| | | def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train", pre_hwlist=None, pre_prob=0.0): |
| | | def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"): |
| | | self.scp_lists = scp_lists |
| | | self.data_names = data_names |
| | | self.data_types = data_types |
| | |
| | | self.world_size = 1 |
| | | self.worker_id = 0 |
| | | self.num_workers = 1 |
| | | self.pre_hwlist = pre_hwlist |
| | | self.pre_prob = pre_prob |
| | | |
| | | def set_epoch(self, epoch): |
| | | self.epoch = epoch |
| | |
| | | data_types = conf.get("data_types", "kaldi_ark,text") |
| | | |
| | | pre_hwfile = conf.get("pre_hwlist", None) |
| | | pre_prob = conf.get("pre_prob", 0) |
| | | pre_prob = conf.get("pre_prob", 0) # unused yet |
| | | |
| | | hw_config = {"sample_rate": conf.get("sample_rate", 0.6), |
| | | "double_rate": conf.get("double_rate", 0.1), |
| | | "hotword_min_length": conf.get("hotword_min_length", 2), |
| | | "hotword_max_length": conf.get("hotword_max_length", 8)} |
| | | |
| | | |
| | | if pre_hwfile is not None: |
| | | pre_hwlist = [] |
| | |
| | | pre_hwlist.append(line.strip()) |
| | | else: |
| | | pre_hwlist = None |
| | | # logging.warning("Previous hwlist: {}".format(pre_hwlist)) |
| | | |
| | | dataset = AudioDataset(scp_lists, |
| | | data_names, |
| | | data_types, |
| | | frontend_conf=frontend_conf, |
| | | shuffle=shuffle, |
| | | mode=mode, |
| | | pre_hwlist=pre_hwlist, |
| | | pre_prob=pre_prob) |
| | | ) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |
| | |
| | | normalize = None |
| | | |
| | | # 4. Encoder |
| | | |
| | | if getattr(args, "encoder", None) is not None: |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size, **args.encoder_conf) |