| | |
| | | import yaml |
| | | from funasr.models.base_model import FunASRModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr import __version__ |
| | | from funasr.datasets.dataset import AbsDataset |
| | |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | from funasr.modules.lora.utils import mark_only_lora_as_trainable |
| | | |
| | | try: |
| | | import wandb |
| | |
| | | |
| | | @classmethod |
| | | def get_parser(cls) -> config_argparse.ArgumentParser: |
| | | assert check_argument_types() |
| | | |
| | | class ArgumentDefaultsRawTextHelpFormatter( |
| | | argparse.RawTextHelpFormatter, |
| | |
| | | default=None, |
| | | help="oss bucket.", |
| | | ) |
| | | group.add_argument( |
| | | "--enable_lora", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Apply lora for finetuning.", |
| | | ) |
| | | group.add_argument( |
| | | "--lora_bias", |
| | | type=str, |
| | | default="none", |
| | | help="lora bias.", |
| | | ) |
| | | |
| | | cls.trainer.add_arguments(parser) |
| | | cls.add_task_arguments(parser) |
| | | |
| | | assert check_return_type(parser) |
| | | return parser |
| | | |
| | | @classmethod |
| | |
| | | return _cls |
| | | |
| | | # This method is used only for --print_config |
| | | assert check_argument_types() |
| | | parser = cls.get_parser() |
| | | args, _ = parser.parse_known_args() |
| | | config = vars(args) |
| | |
| | | |
| | | @classmethod |
| | | def check_required_command_args(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if hasattr(args, "required"): |
| | | for k in vars(args): |
| | | if "-" in k: |
| | |
| | | inference: bool = False, |
| | | ) -> None: |
| | | """Check if the dataset satisfy the requirement of current Task""" |
| | | assert check_argument_types() |
| | | mes = ( |
| | | f"If you intend to use an additional input, modify " |
| | | f'"{cls.__name__}.required_data_names()" or ' |
| | |
| | | |
@classmethod
def print_config(cls, file=sys.stdout) -> None:
    """Write this task's default configuration to ``file``.

    The configuration returned by ``cls.get_default_config()`` is
    serialized with ``yaml_no_alias_safe_dump`` (4-space indent, keys kept
    in their original order) and passed to ``file.write`` in one call.

    Args:
        file: Object exposing ``write(str)``. Defaults to the ``sys.stdout``
            object that was current when this method was defined.
    """
    assert check_argument_types()
    # Backs the --print_config option, e.g. python train.py asr --print_config
    default_config = cls.get_default_config()
    rendered = yaml_no_alias_safe_dump(
        default_config,
        indent=4,
        sort_keys=False,
    )
    file.write(rendered)
| | | |
| | | @classmethod |
| | | def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None): |
| | | assert check_argument_types() |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | if args is None: |
| | | parser = cls.get_parser() |
| | |
| | | |
| | | @classmethod |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | # 0. Init distributed process |
| | | distributed_option = build_dataclass(DistributedOption, args) |
| | |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | ) |
| | | if args.enable_lora: |
| | | mark_only_lora_as_trainable(model, args.lora_bias) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |
| | |
| | | - 4 epoch with "--num_iters_per_epoch" == 4 |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | |
| | | # Overwrite iter_options if any kwargs is given |
| | |
| | | def build_sequence_iter_factory( |
| | | cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if hasattr(args, "frontend_conf"): |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | |
| | | iter_options: IteratorOptions, |
| | | mode: str, |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | |
| | | def build_multiple_iter_factory( |
| | | cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str |
| | | ): |
| | | assert check_argument_types() |
| | | iter_options = cls.build_iter_options(args, distributed_option, mode) |
| | | assert len(iter_options.data_path_and_name_and_type) > 0, len( |
| | | iter_options.data_path_and_name_and_type |
| | |
| | | inference: bool = False, |
| | | ) -> DataLoader: |
| | | """Build DataLoader using iterable dataset""" |
| | | assert check_argument_types() |
| | | # For backward compatibility for pytorch DataLoader |
| | | if collate_fn is not None: |
| | | kwargs = dict(collate_fn=collate_fn) |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |