
import torch.nn
import torch.optim
import yaml
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type

>>> cls.check_task_requirements()
If your model is defined as follows,

>>> from funasr.models.base_model import FunASRModel
>>> class Model(FunASRModel):
...     def forward(self, input, output, opt=None): pass

then "required_data_names" should be as

>>> required_data_names = ('input', 'output')

>>> cls.check_task_requirements()
If your model is defined as follows,

>>> from funasr.models.base_model import FunASRModel
>>> class Model(FunASRModel):
...     def forward(self, input, output, opt=None): pass

then "optional_data_names" should be as

>>> optional_data_names = ('opt',)


@classmethod
@abstractmethod
def build_model(cls, args: argparse.Namespace) -> FunASRModel:
    raise NotImplementedError

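# A minimal sketch of a concrete subclass (names are illustrative, not part
# of this file), tying together the data-name hooks and build_model above:
#
#     class MyASRTask(AbsTask):  # assuming this base class is AbsTask
#         @classmethod
#         def required_data_names(cls, train=True, inference=False):
#             return ("input", "output")
#
#         @classmethod
#         def optional_data_names(cls, train=True, inference=False):
#             return ("opt",)
#
#         @classmethod
#         def build_model(cls, args: argparse.Namespace) -> FunASRModel:
#             # MyModel is a hypothetical FunASRModel subclass configured
#             # from the parsed CLI arguments.
#             return MyModel(**args.model_conf)
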
@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:

        help='Perform on "collect stats" mode',
    )
    group.add_argument(
        "--mc",
        type=str2bool,
        default=False,
        help="Multi-channel input",
    )
    group.add_argument(
        "--write_collected_feats",
        type=str2bool,
        default=False,

    parser.add_argument(
        "--batch_interval",
        type=int,
        default=-1,
        help="The batch interval for saving the model.",
    )
    group.add_argument(
        "--accum_grad",
        type=int,
        default=1,
        help="The number of gradient accumulation steps",
    )
    group.add_argument(
        "--bias_grad_times",
        type=float,
        default=1.0,
        help="Scale factor for gradients of context-related parameters",
    )
    group.add_argument(
        "--no_forward_run",

    group.add_argument(
        "--init_param",
        type=str,
        action="append",
        default=[],
        nargs="*",
        help="Specify the file path used for initialization of parameters. "
        "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
        "where file_path is the model file path, "

| | | "--freeze_param", |
| | | type=str, |
| | | default=[], |
| | | nargs="*", |
| | | action="append", |
| | | help="Freeze parameters", |
| | | ) |
| | | |
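    # Illustrative command line for the two flags above (the entry point,
    # path, and key names are examples only):
    #
    #     python train.py \
    #         --init_param exp/pretrained.pb:encoder:encoder \
    #         --freeze_param encoder
    #
    # This loads encoder weights from the checkpoint and keeps them frozen
    # during fine-tuning.
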

elif args.distributed and args.simple_ddp:
    distributed_option.init_torch_distributed_pai(args)
    args.ngpu = dist.get_world_size()
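    # Scale the configured batch settings by the number of GPUs, e.g.
    # batch_size=32 with ngpu=4 yields a global batch_size of 128.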
    if args.dataset_type == "small" and args.ngpu > 0:
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

# Filter out samples whose wav.scp and text entries do not match.


# 2. Build model
model = cls.build_model(args=args)
if not isinstance(model, FunASRModel):
    raise RuntimeError(
        f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
    )
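# Cast parameters to the configured training dtype (e.g. "float32").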
model = model.to(
    dtype=getattr(torch, args.train_dtype),

    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
    key_file=train_key_file,
    batch_size=args.batch_size,
    mc=args.mc,
    dtype=args.train_dtype,
    num_workers=args.num_workers,
    allow_variable_data_keys=args.allow_variable_data_keys,

    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
    key_file=valid_key_file,
    batch_size=args.valid_batch_size,
    mc=args.mc,
    dtype=args.train_dtype,
    num_workers=args.num_workers,
    allow_variable_data_keys=args.allow_variable_data_keys,


# 7. Build iterator factories
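# Two loading paths: "large" streams batches through LargeDataLoader, while
# "small" builds per-epoch iterator factories via build_iter_factory below.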
if args.dataset_type == "large":
    from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
    train_iter_factory = LargeDataLoader(args, mode="train")
    valid_iter_factory = LargeDataLoader(args, mode="eval")

elif args.dataset_type == "small":
    train_iter_factory = cls.build_iter_factory(
        args=args,

) -> AbsIterFactory:
    assert check_argument_types()

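    # Use the frontend's configured sampling rate ("fs") when available;
    # otherwise fall back to 16 kHz.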
    frontend_conf = getattr(args, "frontend_conf", None)
    if frontend_conf is not None and "fs" in frontend_conf:
        dest_sample_rate = frontend_conf["fs"]
    else:
        dest_sample_rate = 16000


    model_file: Union[Path, str] = None,
    cmvn_file: Union[Path, str] = None,
    device: str = "cpu",
) -> Tuple[FunASRModel, argparse.Namespace]:
    """Build model from the files.

    This method is used for inference or fine-tuning.

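    Illustrative usage (``Task`` stands for a concrete task class; the
    paths and keyword names are examples):

    >>> model, args = Task.build_model_from_file(
    ...     config_file="exp/config.yaml",
    ...     model_file="exp/model.pb",
    ...     device="cuda",
    ... )
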
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | if model_file is not None: |