| | |
| | | import torch.nn |
| | | import torch.optim |
| | | import yaml |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr import __version__ |
| | | from funasr.datasets.dataset import AbsDataset |
| | |
| | | >>> cls.check_task_requirements() |
| | | If your model is defined as following, |
| | | |
| | | >>> from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | >>> class Model(AbsESPnetModel): |
| | | >>> from funasr.models.base_model import FunASRModel |
| | | >>> class Model(FunASRModel): |
| | | ... def forward(self, input, output, opt=None): pass |
| | | |
| | | then "required_data_names" should be as |
| | |
| | | >>> cls.check_task_requirements() |
| | | If your model is defined as follows, |
| | | |
| | | >>> from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | >>> class Model(AbsESPnetModel): |
| | | >>> from funasr.models.base_model import FunASRModel |
| | | >>> class Model(FunASRModel): |
| | | ... def forward(self, input, output, opt=None): pass |
| | | |
| | | then "optional_data_names" should be as |
| | |
| | | |
| | | @classmethod |
| | | @abstractmethod |
| | | def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: |
| | | def build_model(cls, args: argparse.Namespace) -> FunASRModel: |
| | | raise NotImplementedError |
| | | |
| | | |
| | | @classmethod |
| | | def get_parser(cls) -> config_argparse.ArgumentParser: |
| | | assert check_argument_types() |
| | | |
| | | class ArgumentDefaultsRawTextHelpFormatter( |
| | | argparse.RawTextHelpFormatter, |
| | |
| | | help='Perform on "collect stats" mode', |
| | | ) |
| | | group.add_argument( |
| | | "--mc", |
| | | type=bool, |
| | | default=False, |
| | | help="MultiChannel input", |
| | | ) |
| | | group.add_argument( |
| | | "--write_collected_feats", |
| | | type=str2bool, |
| | | default=False, |
| | |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | default=-1, |
| | | help="The batch interval for saving model.", |
| | | ) |
| | | group.add_argument( |
| | |
| | | type=int, |
| | | default=1, |
| | | help="The number of gradient accumulation", |
| | | ) |
| | | group.add_argument( |
| | | "--bias_grad_times", |
| | | type=float, |
| | | default=1.0, |
| | | help="To scale the gradient of contextual related params", |
| | | ) |
| | | group.add_argument( |
| | | "--no_forward_run", |
| | |
| | | group.add_argument( |
| | | "--init_param", |
| | | type=str, |
| | | action="append", |
| | | default=[], |
| | | nargs="*", |
| | | help="Specify the file path used for initialization of parameters. " |
| | | "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', " |
| | | "where file_path is the model file path, " |
| | |
| | | "--freeze_param", |
| | | type=str, |
| | | default=[], |
| | | nargs="*", |
| | | action="append", |
| | | help="Freeze parameters", |
| | | ) |
| | | |
| | |
| | | cls.trainer.add_arguments(parser) |
| | | cls.add_task_arguments(parser) |
| | | |
| | | assert check_return_type(parser) |
| | | return parser |
| | | |
| | | @classmethod |
| | |
| | | return _cls |
| | | |
| | | # This method is used only for --print_config |
| | | assert check_argument_types() |
| | | parser = cls.get_parser() |
| | | args, _ = parser.parse_known_args() |
| | | config = vars(args) |
| | |
| | | |
| | | @classmethod |
| | | def check_required_command_args(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if hasattr(args, "required"): |
| | | for k in vars(args): |
| | | if "-" in k: |
| | |
| | | inference: bool = False, |
| | | ) -> None: |
| | | """Check if the dataset satisfy the requirement of current Task""" |
| | | assert check_argument_types() |
| | | mes = ( |
| | | f"If you intend to use an additional input, modify " |
| | | f'"{cls.__name__}.required_data_names()" or ' |
| | |
| | | |
| | | @classmethod |
| | | def print_config(cls, file=sys.stdout) -> None: |
| | | assert check_argument_types() |
| | | # Shows the config: e.g. python train.py asr --print_config |
| | | config = cls.get_default_config() |
| | | file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False)) |
| | | |
| | | @classmethod |
| | | def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None): |
| | | assert check_argument_types() |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | if args is None: |
| | | parser = cls.get_parser() |
| | |
| | | |
| | | @classmethod |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | # 0. Init distributed process |
| | | distributed_option = build_dataclass(DistributedOption, args) |
| | |
| | | elif args.distributed and args.simple_ddp: |
| | | distributed_option.init_torch_distributed_pai(args) |
| | | args.ngpu = dist.get_world_size() |
| | | if args.dataset_type == "small": |
| | | if args.dataset_type == "small" and args.ngpu > 0: |
| | | if args.batch_size is not None: |
| | | args.batch_size = args.batch_size * args.ngpu |
| | | if args.batch_bins is not None: |
| | | if args.batch_bins is not None and args.ngpu > 0: |
| | | args.batch_bins = args.batch_bins * args.ngpu |
| | | |
| | | # filter samples if wav.scp and text are mismatch |
| | | if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if ( |
| | | args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | filter_wav_text(args.data_dir, args.train_set) |
| | | filter_wav_text(args.data_dir, args.dev_set) |
| | |
| | | |
| | | if args.train_shape_file is None and args.dataset_type == "small": |
| | | if not args.simple_ddp or distributed_option.dist_rank == 0: |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) |
| | | calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max) |
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, |
| | | args.speech_length_max) |
| | | if args.simple_ddp: |
| | | dist.barrier() |
| | | args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")] |
| | |
| | | |
| | | # 2. Build model |
| | | model = cls.build_model(args=args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model = model.to( |
| | | dtype=getattr(torch, args.train_dtype), |
| | |
| | | data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
| | | key_file=train_key_file, |
| | | batch_size=args.batch_size, |
| | | mc=args.mc, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | |
| | | data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
| | | key_file=valid_key_file, |
| | | batch_size=args.valid_batch_size, |
| | | mc=args.mc, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | |
| | | |
| | | # 7. Build iterator factories |
| | | if args.dataset_type == "large": |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="eval") |
| | | from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader |
| | | train_iter_factory = LargeDataLoader(args, mode="train") |
| | | valid_iter_factory = LargeDataLoader(args, mode="eval") |
| | | |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |
| | | args=args, |
| | |
| | | - 4 epoch with "--num_iters_per_epoch" == 4 |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | |
| | | # Overwrite iter_options if any kwargs is given |
| | |
| | | def build_sequence_iter_factory( |
| | | cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | | dest_sample_rate = args.frontend_conf["fs"] |
| | | if hasattr(args, "frontend_conf"): |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | | dest_sample_rate = args.frontend_conf["fs"] |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | |
| | |
| | | iter_options: IteratorOptions, |
| | | mode: str, |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | |
| | | def build_multiple_iter_factory( |
| | | cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str |
| | | ): |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | assert len(iter_options.data_path_and_name_and_type) > 0, len( |
| | | iter_options.data_path_and_name_and_type |
| | |
| | | inference: bool = False, |
| | | ) -> DataLoader: |
| | | """Build DataLoader using iterable dataset""" |
| | | assert check_argument_types() |
| | | # For backward compatibility for pytorch DataLoader |
| | | if collate_fn is not None: |
| | | kwargs = dict(collate_fn=collate_fn) |
| | |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | ) -> Tuple[AbsESPnetModel, argparse.Namespace]: |
| | | ) -> Tuple[FunASRModel, argparse.Namespace]: |
| | | """Build model from the files. |
| | | |
| | | This method is used for inference or fine-tuning. |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | if model_file is not None: |