| | |
| | | import yaml |
| | | from funasr.models.base_model import FunASRModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr import __version__ |
| | | from funasr.datasets.dataset import AbsDataset |
| | |
| | | |
| | | @classmethod |
| | | def get_parser(cls) -> config_argparse.ArgumentParser: |
| | | assert check_argument_types() |
| | | |
| | | class ArgumentDefaultsRawTextHelpFormatter( |
| | | argparse.RawTextHelpFormatter, |
| | |
| | | cls.trainer.add_arguments(parser) |
| | | cls.add_task_arguments(parser) |
| | | |
| | | assert check_return_type(parser) |
| | | return parser |
| | | |
| | | @classmethod |
| | |
| | | return _cls |
| | | |
| | | # This method is used only for --print_config |
| | | assert check_argument_types() |
| | | parser = cls.get_parser() |
| | | args, _ = parser.parse_known_args() |
| | | config = vars(args) |
| | |
| | | |
| | | @classmethod |
| | | def check_required_command_args(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if hasattr(args, "required"): |
| | | for k in vars(args): |
| | | if "-" in k: |
| | |
| | | inference: bool = False, |
| | | ) -> None: |
| | | """Check if the dataset satisfy the requirement of current Task""" |
| | | assert check_argument_types() |
| | | mes = ( |
| | | f"If you intend to use an additional input, modify " |
| | | f'"{cls.__name__}.required_data_names()" or ' |
| | |
| | | |
@classmethod
def print_config(cls, file=sys.stdout) -> None:
    """Write this task's default configuration as YAML to *file*.

    Backs the ``--print_config`` command-line flag so users can dump a
    template configuration to edit.

    Args:
        file: Writable text stream the YAML is emitted to
            (defaults to ``sys.stdout``).
    """
    assert check_argument_types()
    # Shows the config: e.g. python train.py asr --print_config
    config = cls.get_default_config()
    # NOTE(review): yaml_no_alias_safe_dump is a project helper defined
    # elsewhere in this file; presumably a safe_dump variant that suppresses
    # YAML anchors/aliases so the output stays human-editable -- confirm.
    file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
| | | |
| | | @classmethod |
| | | def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None): |
| | | assert check_argument_types() |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | if args is None: |
| | | parser = cls.get_parser() |
| | |
| | | |
| | | @classmethod |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | # 0. Init distributed process |
| | | distributed_option = build_dataclass(DistributedOption, args) |
| | |
| | | - 4 epoch with "--num_iters_per_epoch" == 4 |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | |
| | | # Overwrite iter_options if any kwargs is given |
| | |
| | | def build_sequence_iter_factory( |
| | | cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if hasattr(args, "frontend_conf"): |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | |
| | | iter_options: IteratorOptions, |
| | | mode: str, |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | |
| | | def build_multiple_iter_factory( |
| | | cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str |
| | | ): |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | assert len(iter_options.data_path_and_name_and_type) > 0, len( |
| | | iter_options.data_path_and_name_and_type |
| | |
| | | inference: bool = False, |
| | | ) -> DataLoader: |
| | | """Build DataLoader using iterable dataset""" |
| | | assert check_argument_types() |
| | | # For backward compatibility for pytorch DataLoader |
| | | if collate_fn is not None: |
| | | kwargs = dict(collate_fn=collate_fn) |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |