| | |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | | from funasr.samplers.build_batch_sampler import build_batch_sampler |
| | | from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler |
| | | from funasr.schedulers.noam_lr import NoamLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.schedulers.tri_stage_scheduler import TriStageLR |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_int |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | |
| | | try: |
| | |
| | | |
# Registry of optimizer constructors selectable by name (e.g. --optim adam).
# NOTE(review): the extracted source had these two dicts fused into a single
# unclosed `dict(...)` literal; they are split back apart here. Entries that
# were in the lost span between the two groups (other torch.optim optimizers
# such as adagrad/adamax/rmsprop, and schedulers such as ReduceLROnPlateau)
# could not be recovered from this view — TODO confirm against upstream.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
)

# Registry of LR-scheduler constructors selectable by name. These wrap an
# optimizer instance (not model parameters), so they must not live in
# optim_classes.
scheduler_classes = dict(
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
| | |
| | | if args.batch_bins is not None: |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | # filter samples if wav.scp and text are mismatch |
| | | if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | filter_wav_text(args.data_dir, args.train_set) |
| | | filter_wav_text(args.data_dir, args.dev_set) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | |
| | | if args.train_shape_file is None and args.dataset_type == "small": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) |
| | |
| | | |
| | | if args.dry_run: |
| | | pass |
| | | elif args.collect_stats: |
| | | # Perform on collect_stats mode. This mode has two roles |
| | | # - Derive the length and dimension of all input data |
| | | # - Accumulate feats, square values, and the length for whitening |
| | | |
| | | if args.valid_batch_size is None: |
| | | args.valid_batch_size = args.batch_size |
| | | |
| | | if len(args.train_shape_file) != 0: |
| | | train_key_file = args.train_shape_file[0] |
| | | else: |
| | | train_key_file = None |
| | | if len(args.valid_shape_file) != 0: |
| | | valid_key_file = args.valid_shape_file[0] |
| | | else: |
| | | valid_key_file = None |
| | | |
| | | collect_stats( |
| | | model=model, |
| | | train_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
| | | key_file=train_key_file, |
| | | batch_size=args.batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | valid_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
| | | key_file=valid_key_file, |
| | | batch_size=args.valid_batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | output_dir=output_dir, |
| | | ngpu=args.ngpu, |
| | | log_interval=args.log_interval, |
| | | write_collected_feats=args.write_collected_feats, |
| | | ) |
| | | else: |
| | | logging.info("Training args: {}".format(args)) |
| | | # 6. Loads pre-trained model |
| | |
| | | collate_fn, |
| | | key_file: str = None, |
| | | batch_size: int = 1, |
| | | fs: dict = None, |
| | | dtype: str = np.float32, |
| | | num_workers: int = 1, |
| | | allow_variable_data_keys: bool = False, |
| | |
| | | dataset = IterableESPnetDataset( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | fs=fs, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | ) |