| | |
| | | import torch.nn |
| | | import torch.optim |
| | | import yaml |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | | from funasr.samplers.build_batch_sampler import build_batch_sampler |
| | | from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler |
| | | from funasr.schedulers.noam_lr import NoamLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.schedulers.tri_stage_scheduler import TriStageLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.trainer import Trainer |
| | |
| | | type=int, |
| | | default=sys.maxsize, |
| | | help="The maximum number update step to train", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | help="The batch interval for saving model.", |
| | | ) |
| | | group.add_argument( |
| | | "--patience", |
| | |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | # filter samples if wav.scp and text are mismatch |
| | | if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if ( |
| | | args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | filter_wav_text(args.data_dir, args.train_set) |
| | | filter_wav_text(args.data_dir, args.dev_set) |
| | |
| | | |
| | | if args.train_shape_file is None and args.dataset_type == "small": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) |
| | | calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")] |
| | |
| | | # logging.basicConfig() is invoked in main_worker() instead of main() |
| | | # because it can be invoked only once in a process. |
| | | # FIXME(kamo): Should we use logging.getLogger()? |
| | | # BUGFIX: Remove previous handlers and reset log level |
| | | for handler in logging.root.handlers[:]: |
| | | logging.root.removeHandler(handler) |
| | | logging.basicConfig( |
| | | level=args.log_level, |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | else: |
| | | # BUGFIX: Remove previous handlers and reset log level |
| | | for handler in logging.root.handlers[:]: |
| | | logging.root.removeHandler(handler) |
| | | # Suppress logging if RANK != 0 |
| | | logging.basicConfig( |
| | | level="ERROR", |
| | |
| | | if args.dataset_type == "large": |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="eval") |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |
| | |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | | dest_sample_rate = args.frontend_conf["fs"] |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | | float_dtype=args.train_dtype, |
| | | preprocess=iter_options.preprocess_fn, |
| | | max_cache_size=iter_options.max_cache_size, |
| | | max_cache_fd=iter_options.max_cache_fd, |
| | | dest_sample_rate=args.frontend_conf["fs"], |
| | | dest_sample_rate=dest_sample_rate, |
| | | ) |
| | | cls.check_task_requirements( |
| | | dataset, args.allow_variable_data_keys, train=iter_options.train |