| | |
| | | import humanfriendly |
| | | import numpy as np |
| | | import torch |
| | | import torch.distributed as dist |
| | | import torch.multiprocessing |
| | | import torch.nn |
| | | import torch.optim |
| | | import yaml |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | |
| | | from funasr.datasets.dataset import DATA_TYPES |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | | from funasr.datasets.iterable_dataset import IterableESPnetDataset |
| | | from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | | from funasr.samplers.build_batch_sampler import build_batch_sampler |
| | | from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler |
| | | from funasr.schedulers.noam_lr import NoamLR |
| | | from funasr.schedulers.tri_stage_scheduler import TriStageLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.trainer import Trainer |
| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_int |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | |
| | | try: |
| | |
| | | |
# Mapping from the --optim command-line choice to the optimizer class built
# for the model parameters.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
)
# Mapping from the --scheduler command-line choice to the LR-scheduler class.
# NOTE(review): these scheduler entries were previously appended inside an
# unterminated optim_classes dict(...) call (merge artifact); LR schedulers
# are not optimizers, so they live in their own mapping.
scheduler_classes = dict(
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
    # Number of optimizers this task builds; subclasses override (>1 for
    # GAN-style tasks with separate generator/discriminator optimizers).
    num_optimizers: int = 1
    # Trainer class driving the train/validate loop; subclasses may swap in
    # a custom Trainer subclass.
    trainer = Trainer
    # Per-task ClassChoices registered by subclasses to expose --xxx options.
    class_choices_list: List[ClassChoices] = []
    # NOTE(review): annotation-only statement — this does NOT assign a value.
    # The attribute stays undefined until a caller sets it (run() asserts
    # hasattr(cls, "finetune_args")); presumably an argparse.Namespace.
    finetune_args: None

    def __init__(self):
        """Forbid instantiation: the task is a namespace of classmethods."""
        raise RuntimeError("This class can't be instantiated.")
| | |
| | | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | parser.set_defaults(required=["output_dir"]) |
| | | # parser.set_defaults(required=["output_dir"]) |
| | | |
| | | group = parser.add_argument_group("Common configuration") |
| | | |
| | |
| | | default=sys.maxsize, |
| | | help="The maximum number update step to train", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | help="The batch interval for saving model.", |
| | | ) |
| | | group.add_argument( |
| | | "--patience", |
| | | type=int_or_none, |
| | |
| | | "and exclude_keys excludes keys of model states for the initialization." |
| | | "e.g.\n" |
| | | " # Load all parameters" |
| | | " --init_param some/where/model.pth\n" |
| | | " --init_param some/where/model.pb\n" |
| | | " # Load only decoder parameters" |
| | | " --init_param some/where/model.pth:decoder:decoder\n" |
| | | " --init_param some/where/model.pb:decoder:decoder\n" |
| | | " # Load only decoder parameters excluding decoder.embed" |
| | | " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n" |
| | | " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n", |
| | | " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n" |
| | | " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n", |
| | | ) |
| | | group.add_argument( |
| | | "--ignore_init_mismatch", |
| | |
| | | group.add_argument( |
| | | "--batch_type", |
| | | type=str, |
| | | default="folded", |
| | | default="length", |
| | | choices=list(BATCH_TYPES), |
| | | help=_batch_type_help, |
| | | ) |
| | |
| | | default=None, |
| | | choices=list(BATCH_TYPES) + [None], |
| | | help="If not given, the value of --batch_type is used", |
| | | ) |
| | | group.add_argument( |
| | | "--speech_length_min", |
| | | type=int, |
| | | default=-1, |
| | | help="speech length min", |
| | | ) |
| | | group.add_argument( |
| | | "--speech_length_max", |
| | | type=int, |
| | | default=-1, |
| | | help="speech length max", |
| | | ) |
| | | group.add_argument("--fold_length", type=int, action="append", default=[]) |
| | | group.add_argument( |
| | |
| | | help="flag to indicate whether training on PAI", |
| | | ) |
| | | group.add_argument( |
| | | "--simple_ddp", |
| | | type=str2bool, |
| | | default=False, |
| | | ) |
| | | group.add_argument( |
| | | "--num_worker_count", |
| | | type=int, |
| | | default=1, |
| | |
| | | @classmethod |
| | | def check_required_command_args(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | for k in vars(args): |
| | | if "-" in k: |
| | | raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")') |
| | | if hasattr(args, "required"): |
| | | for k in vars(args): |
| | | if "-" in k: |
| | | raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")') |
| | | |
| | | required = ", ".join( |
| | | f"--{a}" for a in args.required if getattr(args, a) is None |
| | | ) |
| | | |
| | | if len(required) != 0: |
| | | parser = cls.get_parser() |
| | | parser.print_help(file=sys.stderr) |
| | | p = Path(sys.argv[0]).name |
| | | print(file=sys.stderr) |
| | | print( |
| | | f"{p}: error: the following arguments are required: " f"{required}", |
| | | file=sys.stderr, |
| | | required = ", ".join( |
| | | f"--{a}" for a in args.required if getattr(args, a) is None |
| | | ) |
| | | sys.exit(2) |
| | | |
| | | if len(required) != 0: |
| | | parser = cls.get_parser() |
| | | parser.print_help(file=sys.stderr) |
| | | p = Path(sys.argv[0]).name |
| | | print(file=sys.stderr) |
| | | print( |
| | | f"{p}: error: the following arguments are required: " f"{required}", |
| | | file=sys.stderr, |
| | | ) |
| | | sys.exit(2) |
| | | |
| | | @classmethod |
| | | def check_task_requirements( |
| | | cls, |
| | | dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope], |
| | | dataset: Union[AbsDataset, IterableESPnetDataset], |
| | | allow_variable_data_keys: bool, |
| | | train: bool, |
| | | inference: bool = False, |
| | |
| | | cls.main_worker(args) |
| | | |
| | | @classmethod |
| | | def run(cls): |
| | | assert hasattr(cls, "finetune_args") |
| | | args = cls.finetune_args |
| | | args.train_shape_file = None |
| | | if args.distributed: |
| | | args.simple_ddp = True |
| | | else: |
| | | args.simple_ddp = False |
| | | args.ngpu = 1 |
| | | args.use_pai = False |
| | | args.batch_type = "length" |
| | | args.oss_bucket = None |
| | | args.input_size = None |
| | | cls.main_worker(args) |
| | | |
| | | @classmethod |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | |
| | | # Setting distributed_option.dist_rank, etc. |
| | | if args.use_pai: |
| | | distributed_option.init_options_pai() |
| | | else: |
| | | elif not args.simple_ddp: |
| | | distributed_option.init_options() |
| | | |
| | | # Invoking torch.distributed.init_process_group |
| | | if args.use_pai: |
| | | distributed_option.init_torch_distributed_pai(args) |
| | | elif not args.simple_ddp: |
| | | distributed_option.init_torch_distributed(args) |
| | | elif args.distributed and args.simple_ddp: |
| | | distributed_option.init_torch_distributed_pai(args) |
| | | args.ngpu = dist.get_world_size() |
| | | if args.dataset_type == "small": |
| | | if args.batch_size is not None: |
| | | args.batch_size = args.batch_size * args.ngpu |
| | | if args.batch_bins is not None: |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | # filter samples if wav.scp and text are mismatch |
| | | if ( |
| | | args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | filter_wav_text(args.data_dir, args.train_set) |
| | | filter_wav_text(args.data_dir, args.dev_set) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | |
| | | if args.train_shape_file is None and args.dataset_type == "small": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")] |
| | | args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")] |
| | | |
| | | if args.train_data_file is None and args.dataset_type == "large": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | generate_data_list(args.data_dir, args.train_set) |
| | | generate_data_list(args.data_dir, args.dev_set) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list") |
| | | args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list") |
| | | |
| | | # NOTE(kamo): Don't use logging before invoking logging.basicConfig() |
| | | if not distributed_option.distributed or distributed_option.dist_rank == 0: |
| | |
| | | # logging.basicConfig() is invoked in main_worker() instead of main() |
| | | # because it can be invoked only once in a process. |
| | | # FIXME(kamo): Should we use logging.getLogger()? |
| | | # BUGFIX: Remove previous handlers and reset log level |
| | | for handler in logging.root.handlers[:]: |
| | | logging.root.removeHandler(handler) |
| | | logging.basicConfig( |
| | | level=args.log_level, |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | else: |
| | | # BUGFIX: Remove previous handlers and reset log level |
| | | for handler in logging.root.handlers[:]: |
| | | logging.root.removeHandler(handler) |
| | | # Suppress logging if RANK != 0 |
| | | logging.basicConfig( |
| | | level="ERROR", |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | # Invoking torch.distributed.init_process_group |
| | | if args.use_pai: |
| | | distributed_option.init_torch_distributed_pai(args) |
| | | else: |
| | | distributed_option.init_torch_distributed(args) |
| | | logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size, |
| | | distributed_option.dist_rank, |
| | | distributed_option.local_rank)) |
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(args.seed) |
| | |
| | | |
| | | if args.dry_run: |
| | | pass |
| | | elif args.collect_stats: |
| | | # Perform on collect_stats mode. This mode has two roles |
| | | # - Derive the length and dimension of all input data |
| | | # - Accumulate feats, square values, and the length for whitening |
| | | |
| | | if args.valid_batch_size is None: |
| | | args.valid_batch_size = args.batch_size |
| | | |
| | | if len(args.train_shape_file) != 0: |
| | | train_key_file = args.train_shape_file[0] |
| | | else: |
| | | train_key_file = None |
| | | if len(args.valid_shape_file) != 0: |
| | | valid_key_file = args.valid_shape_file[0] |
| | | else: |
| | | valid_key_file = None |
| | | |
| | | collect_stats( |
| | | model=model, |
| | | train_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
| | | key_file=train_key_file, |
| | | batch_size=args.batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | valid_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
| | | key_file=valid_key_file, |
| | | batch_size=args.valid_batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | output_dir=output_dir, |
| | | ngpu=args.ngpu, |
| | | log_interval=args.log_interval, |
| | | write_collected_feats=args.write_collected_feats, |
| | | ) |
| | | else: |
| | | logging.info("Training args: {}".format(args)) |
| | | # 6. Loads pre-trained model |
| | |
| | | # 7. Build iterator factories |
| | | if args.dataset_type == "large": |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, |
| | | args.config, mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, |
| | | args.config, mode="eval") |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="eval") |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |
| | | args=args, |
| | |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | | dest_sample_rate = args.frontend_conf["fs"] |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | | float_dtype=args.train_dtype, |
| | | preprocess=iter_options.preprocess_fn, |
| | | max_cache_size=iter_options.max_cache_size, |
| | | max_cache_fd=iter_options.max_cache_fd, |
| | | dest_sample_rate=dest_sample_rate, |
| | | ) |
| | | cls.check_task_requirements( |
| | | dataset, args.allow_variable_data_keys, train=iter_options.train |
| | |
| | | collate_fn, |
| | | key_file: str = None, |
| | | batch_size: int = 1, |
| | | fs: dict = None, |
| | | mc: bool = False, |
| | | dtype: str = np.float32, |
| | | num_workers: int = 1, |
| | | allow_variable_data_keys: bool = False, |
| | |
| | | dataset = IterableESPnetDataset( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | fs=fs, |
| | | mc=mc, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | ) |
| | |
| | | **kwargs, |
| | | ) |
| | | |
| | | @classmethod |
| | | def build_streaming_iterator_modelscope( |
| | | cls, |
| | | data_path_and_name_and_type, |
| | | preprocess_fn, |
| | | collate_fn, |
| | | key_file: str = None, |
| | | batch_size: int = 1, |
| | | dtype: str = np.float32, |
| | | num_workers: int = 1, |
| | | allow_variable_data_keys: bool = False, |
| | | ngpu: int = 0, |
| | | inference: bool = False, |
| | | sample_rate: Union[dict, int] = 16000 |
| | | ) -> DataLoader: |
| | | """Build DataLoader using iterable dataset""" |
| | | assert check_argument_types() |
| | | # For backward compatibility for pytorch DataLoader |
| | | if collate_fn is not None: |
| | | kwargs = dict(collate_fn=collate_fn) |
| | | else: |
| | | kwargs = {} |
| | | |
| | | audio_data = data_path_and_name_and_type[0] |
| | | if isinstance(audio_data, bytes): |
| | | dataset = IterableESPnetBytesModelScope( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | sample_rate=sample_rate |
| | | ) |
| | | else: |
| | | dataset = IterableESPnetDatasetModelScope( |
| | | data_path_and_name_and_type, |
| | | float_dtype=dtype, |
| | | preprocess=preprocess_fn, |
| | | key_file=key_file, |
| | | sample_rate=sample_rate |
| | | ) |
| | | |
| | | if dataset.apply_utt2category: |
| | | kwargs.update(batch_size=1) |
| | | else: |
| | | kwargs.update(batch_size=batch_size) |
| | | |
| | | cls.check_task_requirements(dataset, |
| | | allow_variable_data_keys, |
| | | train=False, |
| | | inference=inference) |
| | | |
| | | return DataLoader( |
| | | dataset=dataset, |
| | | pin_memory=ngpu > 0, |
| | | num_workers=num_workers, |
| | | **kwargs, |
| | | ) |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | | @classmethod |
| | | def build_model_from_file( |
| | | cls, |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | ) -> Tuple[AbsESPnetModel, argparse.Namespace]: |
| | | """Build model from the files. |
| | |
| | | |
| | | with config_file.open("r", encoding="utf-8") as f: |
| | | args = yaml.safe_load(f) |
| | | if cmvn_file is not None: |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | |
| | | # in PyTorch<=1.4 |
| | | device = f"cuda:{torch.cuda.current_device()}" |
| | | model.load_state_dict(torch.load(model_file, map_location=device)) |
| | | |
| | | model.to(device) |
| | | return model, args |