From d5d7363da1f56c6932cba2901cc4b9d6f130a069 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 22:56:20 +0800
Subject: [PATCH] train

---
 funasr/tasks/abs_task.py |  761 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 733 insertions(+), 28 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 55a5d79..5940d0c 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@
 import torch.nn
 import torch.optim
 import yaml
+from funasr.models.base_model import FunASRModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -44,19 +45,18 @@
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer
@@ -230,8 +230,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as following,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "required_data_names" should be as
@@ -251,8 +251,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as follows,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "optional_data_names" should be as
@@ -263,11 +263,11 @@
 
     @classmethod
     @abstractmethod
-    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
         raise NotImplementedError
 
     @classmethod
-    def get_parser(cls) -> config_argparse.ArgumentParser:
+    def get_parser(cls, parser) -> config_argparse.ArgumentParser:
         assert check_argument_types()
 
         class ArgumentDefaultsRawTextHelpFormatter(
@@ -276,10 +276,10 @@
         ):
             pass
 
-        parser = config_argparse.ArgumentParser(
-            description="base parser",
-            formatter_class=ArgumentDefaultsRawTextHelpFormatter,
-        )
+        # parser = config_argparse.ArgumentParser(
+        #     description="base parser",
+        #     formatter_class=ArgumentDefaultsRawTextHelpFormatter,
+        # )
 
         # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
         #  I think '-' looks really confusing if it's written in yaml.
@@ -961,6 +961,702 @@
         assert check_return_type(parser)
         return parser
 
+
+    # @classmethod
+    # def get_parser(cls) -> config_argparse.ArgumentParser:
+    #     assert check_argument_types()
+    #
+    #     class ArgumentDefaultsRawTextHelpFormatter(
+    #         argparse.RawTextHelpFormatter,
+    #         argparse.ArgumentDefaultsHelpFormatter,
+    #     ):
+    #         pass
+    #
+    #     parser = config_argparse.ArgumentParser(
+    #         description="base parser",
+    #         formatter_class=ArgumentDefaultsRawTextHelpFormatter,
+    #     )
+    #
+    #     # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
+    #     #  I think '-' looks really confusing if it's written in yaml.
+    #
+    #     # NOTE(kamo): add_arguments(..., required=True) can't be used
+    #     #  to provide --print_config mode. Instead of it, do as
+    #     # parser.set_defaults(required=["output_dir"])
+    #
+    #     group = parser.add_argument_group("Common configuration")
+    #
+    #     group.add_argument(
+    #         "--print_config",
+    #         action="store_true",
+    #         help="Print the config file and exit",
+    #     )
+    #     group.add_argument(
+    #         "--log_level",
+    #         type=lambda x: x.upper(),
+    #         default="INFO",
+    #         choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+    #         help="The verbose level of logging",
+    #     )
+    #     group.add_argument(
+    #         "--dry_run",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Perform process without training",
+    #     )
+    #     group.add_argument(
+    #         "--iterator_type",
+    #         type=str,
+    #         choices=["sequence", "chunk", "task", "none"],
+    #         default="sequence",
+    #         help="Specify iterator type",
+    #     )
+    #
+    #     group.add_argument("--output_dir", type=str_or_none, default=None)
+    #     group.add_argument(
+    #         "--ngpu",
+    #         type=int,
+    #         default=0,
+    #         help="The number of gpus. 0 indicates CPU mode",
+    #     )
+    #     group.add_argument("--seed", type=int, default=0, help="Random seed")
+    #     group.add_argument(
+    #         "--num_workers",
+    #         type=int,
+    #         default=1,
+    #         help="The number of workers used for DataLoader",
+    #     )
+    #     group.add_argument(
+    #         "--num_att_plot",
+    #         type=int,
+    #         default=3,
+    #         help="The number images to plot the outputs from attention. "
+    #              "This option makes sense only when attention-based model. "
+    #              "We can also disable the attention plot by setting it 0",
+    #     )
+    #
+    #     group = parser.add_argument_group("distributed training related")
+    #     group.add_argument(
+    #         "--dist_backend",
+    #         default="nccl",
+    #         type=str,
+    #         help="distributed backend",
+    #     )
+    #     group.add_argument(
+    #         "--dist_init_method",
+    #         type=str,
+    #         default="env://",
+    #         help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
+    #              '"WORLD_SIZE", and "RANK" are referred.',
+    #     )
+    #     group.add_argument(
+    #         "--dist_world_size",
+    #         default=None,
+    #         type=int_or_none,
+    #         help="number of nodes for distributed training",
+    #     )
+    #     group.add_argument(
+    #         "--dist_rank",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="node rank for distributed training",
+    #     )
+    #     group.add_argument(
+    #         # Not starting with "dist_" for compatibility to launch.py
+    #         "--local_rank",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="local rank for distributed training. This option is used if "
+    #              "--multiprocessing_distributed=false",
+    #     )
+    #     group.add_argument(
+    #         "--dist_master_addr",
+    #         default=None,
+    #         type=str_or_none,
+    #         help="The master address for distributed training. "
+    #              "This value is used when dist_init_method == 'env://'",
+    #     )
+    #     group.add_argument(
+    #         "--dist_master_port",
+    #         default=None,
+    #         type=int_or_none,
+    #         help="The master port for distributed training"
+    #              "This value is used when dist_init_method == 'env://'",
+    #     )
+    #     group.add_argument(
+    #         "--dist_launcher",
+    #         default=None,
+    #         type=str_or_none,
+    #         choices=["slurm", "mpi", None],
+    #         help="The launcher type for distributed training",
+    #     )
+    #     group.add_argument(
+    #         "--multiprocessing_distributed",
+    #         default=False,
+    #         type=str2bool,
+    #         help="Use multi-processing distributed training to launch "
+    #              "N processes per node, which has N GPUs. This is the "
+    #              "fastest way to use PyTorch for either single node or "
+    #              "multi node data parallel training",
+    #     )
+    #     group.add_argument(
+    #         "--unused_parameters",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Whether to use the find_unused_parameters in "
+    #              "torch.nn.parallel.DistributedDataParallel ",
+    #     )
+    #     group.add_argument(
+    #         "--sharded_ddp",
+    #         default=False,
+    #         type=str2bool,
+    #         help="Enable sharded training provided by fairscale",
+    #     )
+    #
+    #     group = parser.add_argument_group("cudnn mode related")
+    #     group.add_argument(
+    #         "--cudnn_enabled",
+    #         type=str2bool,
+    #         default=torch.backends.cudnn.enabled,
+    #         help="Enable CUDNN",
+    #     )
+    #     group.add_argument(
+    #         "--cudnn_benchmark",
+    #         type=str2bool,
+    #         default=torch.backends.cudnn.benchmark,
+    #         help="Enable cudnn-benchmark mode",
+    #     )
+    #     group.add_argument(
+    #         "--cudnn_deterministic",
+    #         type=str2bool,
+    #         default=True,
+    #         help="Enable cudnn-deterministic mode",
+    #     )
+    #
+    #     group = parser.add_argument_group("collect stats mode related")
+    #     group.add_argument(
+    #         "--collect_stats",
+    #         type=str2bool,
+    #         default=False,
+    #         help='Perform on "collect stats" mode',
+    #     )
+    #     group.add_argument(
+    #         "--mc",
+    #         type=bool,
+    #         default=False,
+    #         help="MultiChannel input",
+    #     )
+    #     group.add_argument(
+    #         "--write_collected_feats",
+    #         type=str2bool,
+    #         default=False,
+    #         help='Write the output features from the model when "collect stats" mode',
+    #     )
+    #
+    #     group = parser.add_argument_group("Trainer related")
+    #     group.add_argument(
+    #         "--max_epoch",
+    #         type=int,
+    #         default=40,
+    #         help="The maximum number epoch to train",
+    #     )
+    #     group.add_argument(
+    #         "--max_update",
+    #         type=int,
+    #         default=sys.maxsize,
+    #         help="The maximum number update step to train",
+    #     )
+    #     parser.add_argument(
+    #         "--batch_interval",
+    #         type=int,
+    #         default=-1,
+    #         help="The batch interval for saving model.",
+    #     )
+    #     group.add_argument(
+    #         "--patience",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="Number of epochs to wait without improvement "
+    #              "before stopping the training",
+    #     )
+    #     group.add_argument(
+    #         "--val_scheduler_criterion",
+    #         type=str,
+    #         nargs=2,
+    #         default=("valid", "loss"),
+    #         help="The criterion used for the value given to the lr scheduler. "
+    #              'Give a pair referring the phase, "train" or "valid",'
+    #              'and the criterion name. The mode specifying "min" or "max" can '
+    #              "be changed by --scheduler_conf",
+    #     )
+    #     group.add_argument(
+    #         "--early_stopping_criterion",
+    #         type=str,
+    #         nargs=3,
+    #         default=("valid", "loss", "min"),
+    #         help="The criterion used for judging of early stopping. "
+    #              'Give a pair referring the phase, "train" or "valid",'
+    #              'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
+    #     )
+    #     group.add_argument(
+    #         "--best_model_criterion",
+    #         type=str2triple_str,
+    #         nargs="+",
+    #         default=[
+    #             ("train", "loss", "min"),
+    #             ("valid", "loss", "min"),
+    #             ("train", "acc", "max"),
+    #             ("valid", "acc", "max"),
+    #         ],
+    #         help="The criterion used for judging of the best model. "
+    #              'Give a pair referring the phase, "train" or "valid",'
+    #              'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
+    #     )
+    #     group.add_argument(
+    #         "--keep_nbest_models",
+    #         type=int,
+    #         nargs="+",
+    #         default=[10],
+    #         help="Remove previous snapshots excluding the n-best scored epochs",
+    #     )
+    #     group.add_argument(
+    #         "--nbest_averaging_interval",
+    #         type=int,
+    #         default=0,
+    #         help="The epoch interval to apply model averaging and save nbest models",
+    #     )
+    #     group.add_argument(
+    #         "--grad_clip",
+    #         type=float,
+    #         default=5.0,
+    #         help="Gradient norm threshold to clip",
+    #     )
+    #     group.add_argument(
+    #         "--grad_clip_type",
+    #         type=float,
+    #         default=2.0,
+    #         help="The type of the used p-norm for gradient clip. Can be inf",
+    #     )
+    #     group.add_argument(
+    #         "--grad_noise",
+    #         type=str2bool,
+    #         default=False,
+    #         help="The flag to switch to use noise injection to "
+    #              "gradients during training",
+    #     )
+    #     group.add_argument(
+    #         "--accum_grad",
+    #         type=int,
+    #         default=1,
+    #         help="The number of gradient accumulation",
+    #     )
+    #     group.add_argument(
+    #         "--bias_grad_times",
+    #         type=float,
+    #         default=1.0,
+    #         help="To scale the gradient of contextual related params",
+    #     )
+    #     group.add_argument(
+    #         "--no_forward_run",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Just only iterating data loading without "
+    #              "model forwarding and training",
+    #     )
+    #     group.add_argument(
+    #         "--resume",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Enable resuming if checkpoint is existing",
+    #     )
+    #     group.add_argument(
+    #         "--train_dtype",
+    #         default="float32",
+    #         choices=["float16", "float32", "float64"],
+    #         help="Data type for training.",
+    #     )
+    #     group.add_argument(
+    #         "--use_amp",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
+    #     )
+    #     group.add_argument(
+    #         "--log_interval",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="Show the logs every the number iterations in each epochs at the "
+    #              "training phase. If None is given, it is decided according the number "
+    #              "of training samples automatically .",
+    #     )
+    #     group.add_argument(
+    #         "--use_tensorboard",
+    #         type=str2bool,
+    #         default=True,
+    #         help="Enable tensorboard logging",
+    #     )
+    #     group.add_argument(
+    #         "--use_wandb",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Enable wandb logging",
+    #     )
+    #     group.add_argument(
+    #         "--wandb_project",
+    #         type=str,
+    #         default=None,
+    #         help="Specify wandb project",
+    #     )
+    #     group.add_argument(
+    #         "--wandb_id",
+    #         type=str,
+    #         default=None,
+    #         help="Specify wandb id",
+    #     )
+    #     group.add_argument(
+    #         "--wandb_entity",
+    #         type=str,
+    #         default=None,
+    #         help="Specify wandb entity",
+    #     )
+    #     group.add_argument(
+    #         "--wandb_name",
+    #         type=str,
+    #         default=None,
+    #         help="Specify wandb run name",
+    #     )
+    #     group.add_argument(
+    #         "--wandb_model_log_interval",
+    #         type=int,
+    #         default=-1,
+    #         help="Set the model log period",
+    #     )
+    #     group.add_argument(
+    #         "--detect_anomaly",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Set torch.autograd.set_detect_anomaly",
+    #     )
+    #
+    #     group = parser.add_argument_group("Pretraining model related")
+    #     group.add_argument("--pretrain_path", help="This option is obsoleted")
+    #     group.add_argument(
+    #         "--init_param",
+    #         type=str,
+    #         action="append",
+    #         default=[],
+    #         help="Specify the file path used for initialization of parameters. "
+    #              "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
+    #              "where file_path is the model file path, "
+    #              "src_key specifies the key of model states to be used in the model file, "
+    #              "dst_key specifies the attribute of the model to be initialized, "
+    #              "and exclude_keys excludes keys of model states for the initialization."
+    #              "e.g.\n"
+    #              "  # Load all parameters"
+    #              "  --init_param some/where/model.pb\n"
+    #              "  # Load only decoder parameters"
+    #              "  --init_param some/where/model.pb:decoder:decoder\n"
+    #              "  # Load only decoder parameters excluding decoder.embed"
+    #              "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+    #              "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
+    #     )
+    #     group.add_argument(
+    #         "--ignore_init_mismatch",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Ignore size mismatch when loading pre-trained model",
+    #     )
+    #     group.add_argument(
+    #         "--freeze_param",
+    #         type=str,
+    #         default=[],
+    #         action="append",
+    #         help="Freeze parameters",
+    #     )
+    #
+    #     group = parser.add_argument_group("BatchSampler related")
+    #     group.add_argument(
+    #         "--num_iters_per_epoch",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="Restrict the number of iterations for training per epoch",
+    #     )
+    #     group.add_argument(
+    #         "--batch_size",
+    #         type=int,
+    #         default=20,
+    #         help="The mini-batch size used for training. Used if batch_type='unsorted',"
+    #              " 'sorted', or 'folded'.",
+    #     )
+    #     group.add_argument(
+    #         "--valid_batch_size",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="If not given, the value of --batch_size is used",
+    #     )
+    #     group.add_argument(
+    #         "--batch_bins",
+    #         type=int,
+    #         default=1000000,
+    #         help="The number of batch bins. Used if batch_type='length' or 'numel'",
+    #     )
+    #     group.add_argument(
+    #         "--valid_batch_bins",
+    #         type=int_or_none,
+    #         default=None,
+    #         help="If not given, the value of --batch_bins is used",
+    #     )
+    #
+    #     group.add_argument("--train_shape_file", type=str, action="append", default=[])
+    #     group.add_argument("--valid_shape_file", type=str, action="append", default=[])
+    #
+    #     group = parser.add_argument_group("Sequence iterator related")
+    #     _batch_type_help = ""
+    #     for key, value in BATCH_TYPES.items():
+    #         _batch_type_help += f'"{key}":\n{value}\n'
+    #     group.add_argument(
+    #         "--batch_type",
+    #         type=str,
+    #         default="length",
+    #         choices=list(BATCH_TYPES),
+    #         help=_batch_type_help,
+    #     )
+    #     group.add_argument(
+    #         "--valid_batch_type",
+    #         type=str_or_none,
+    #         default=None,
+    #         choices=list(BATCH_TYPES) + [None],
+    #         help="If not given, the value of --batch_type is used",
+    #     )
+    #     group.add_argument(
+    #         "--speech_length_min",
+    #         type=int,
+    #         default=-1,
+    #         help="speech length min",
+    #     )
+    #     group.add_argument(
+    #         "--speech_length_max",
+    #         type=int,
+    #         default=-1,
+    #         help="speech length max",
+    #     )
+    #     group.add_argument("--fold_length", type=int, action="append", default=[])
+    #     group.add_argument(
+    #         "--sort_in_batch",
+    #         type=str,
+    #         default="descending",
+    #         choices=["descending", "ascending"],
+    #         help="Sort the samples in each mini-batches by the sample "
+    #              'lengths. To enable this, "shape_file" must have the length information.',
+    #     )
+    #     group.add_argument(
+    #         "--sort_batch",
+    #         type=str,
+    #         default="descending",
+    #         choices=["descending", "ascending"],
+    #         help="Sort mini-batches by the sample lengths",
+    #     )
+    #     group.add_argument(
+    #         "--multiple_iterator",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Use multiple iterator mode",
+    #     )
+    #
+    #     group = parser.add_argument_group("Chunk iterator related")
+    #     group.add_argument(
+    #         "--chunk_length",
+    #         type=str_or_int,
+    #         default=500,
+    #         help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
+    #              "If multiple numbers separated by command are given, "
+    #              "one of them is selected randomly for each samples. "
+    #              "If two numbers are given with '-', it indicates the range of the choices. "
+    #              "Note that if the sequence length is shorter than the all chunk_lengths, "
+    #              "the sample is discarded. ",
+    #     )
+    #     group.add_argument(
+    #         "--chunk_shift_ratio",
+    #         type=float,
+    #         default=0.5,
+    #         help="Specify the shift width of chunks. If it's less than 1, "
+    #              "allows the overlapping and if bigger than 1, there are some gaps "
+    #              "between each chunk.",
+    #     )
+    #     group.add_argument(
+    #         "--num_cache_chunks",
+    #         type=int,
+    #         default=1024,
+    #         help="Shuffle in the specified number of chunks and generate mini-batches "
+    #              "More larger this value, more randomness can be obtained.",
+    #     )
+    #
+    #     group = parser.add_argument_group("Dataset related")
+    #     _data_path_and_name_and_type_help = (
+    #         "Give three words splitted by comma. It's used for the training data. "
+    #         "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
+    #         "The first value, some/path/a.scp, indicates the file path, "
+    #         "and the second, foo, is the key name used for the mini-batch data, "
+    #         "and the last, sound, decides the file type. "
+    #         "This option is repeatable, so you can input any number of features "
+    #         "for your task. Supported file types are as follows:\n\n"
+    #     )
+    #     for key, dic in DATA_TYPES.items():
+    #         _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
+    #
+    #     # for large dataset
+    #     group.add_argument(
+    #         "--dataset_type",
+    #         type=str,
+    #         default="small",
+    #         help="whether to use dataloader for large dataset",
+    #     )
+    #     parser.add_argument(
+    #         "--dataset_conf",
+    #         action=NestedDictAction,
+    #         default=dict(),
+    #         help=f"The keyword arguments for dataset",
+    #     )
+    #     group.add_argument(
+    #         "--train_data_file",
+    #         type=str,
+    #         default=None,
+    #         help="train_list for large dataset",
+    #     )
+    #     group.add_argument(
+    #         "--valid_data_file",
+    #         type=str,
+    #         default=None,
+    #         help="valid_list for large dataset",
+    #     )
+    #
+    #     group.add_argument(
+    #         "--train_data_path_and_name_and_type",
+    #         type=str2triple_str,
+    #         action="append",
+    #         default=[],
+    #         help=_data_path_and_name_and_type_help,
+    #     )
+    #     group.add_argument(
+    #         "--valid_data_path_and_name_and_type",
+    #         type=str2triple_str,
+    #         action="append",
+    #         default=[],
+    #     )
+    #     group.add_argument(
+    #         "--allow_variable_data_keys",
+    #         type=str2bool,
+    #         default=False,
+    #         help="Allow the arbitrary keys for mini-batch with ignoring "
+    #              "the task requirements",
+    #     )
+    #     group.add_argument(
+    #         "--max_cache_size",
+    #         type=humanfriendly.parse_size,
+    #         default=0.0,
+    #         help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
+    #     )
+    #     group.add_argument(
+    #         "--max_cache_fd",
+    #         type=int,
+    #         default=32,
+    #         help="The maximum number of file descriptors to be kept "
+    #              "as opened for ark files. "
+    #              "This feature is only valid when data type is 'kaldi_ark'.",
+    #     )
+    #     group.add_argument(
+    #         "--valid_max_cache_size",
+    #         type=humanfriendly_parse_size_or_none,
+    #         default=None,
+    #         help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
+    #              "If None, the 5 percent size of --max_cache_size",
+    #     )
+    #
+    #     group = parser.add_argument_group("Optimizer related")
+    #     for i in range(1, cls.num_optimizers + 1):
+    #         suf = "" if i == 1 else str(i)
+    #         group.add_argument(
+    #             f"--optim{suf}",
+    #             type=lambda x: x.lower(),
+    #             default="adadelta",
+    #             choices=list(optim_classes),
+    #             help="The optimizer type",
+    #         )
+    #         group.add_argument(
+    #             f"--optim{suf}_conf",
+    #             action=NestedDictAction,
+    #             default=dict(),
+    #             help="The keyword arguments for optimizer",
+    #         )
+    #         group.add_argument(
+    #             f"--scheduler{suf}",
+    #             type=lambda x: str_or_none(x.lower()),
+    #             default=None,
+    #             choices=list(scheduler_classes) + [None],
+    #             help="The lr scheduler type",
+    #         )
+    #         group.add_argument(
+    #             f"--scheduler{suf}_conf",
+    #             action=NestedDictAction,
+    #             default=dict(),
+    #             help="The keyword arguments for lr scheduler",
+    #         )
+    #
+    #     # for training on PAI
+    #     group = parser.add_argument_group("PAI training related")
+    #     group.add_argument(
+    #         "--use_pai",
+    #         type=str2bool,
+    #         default=False,
+    #         help="flag to indicate whether training on PAI",
+    #     )
+    #     group.add_argument(
+    #         "--simple_ddp",
+    #         type=str2bool,
+    #         default=False,
+    #     )
+    #     group.add_argument(
+    #         "--num_worker_count",
+    #         type=int,
+    #         default=1,
+    #         help="The number of machines on PAI.",
+    #     )
+    #     group.add_argument(
+    #         "--access_key_id",
+    #         type=str,
+    #         default=None,
+    #         help="The username for oss.",
+    #     )
+    #     group.add_argument(
+    #         "--access_key_secret",
+    #         type=str,
+    #         default=None,
+    #         help="The password for oss.",
+    #     )
+    #     group.add_argument(
+    #         "--endpoint",
+    #         type=str,
+    #         default=None,
+    #         help="The endpoint for oss.",
+    #     )
+    #     group.add_argument(
+    #         "--bucket_name",
+    #         type=str,
+    #         default=None,
+    #         help="The bucket name for oss.",
+    #     )
+    #     group.add_argument(
+    #         "--oss_bucket",
+    #         default=None,
+    #         help="oss bucket.",
+    #     )
+    #
+    #     cls.trainer.add_arguments(parser)
+    #     cls.add_task_arguments(parser)
+    #
+    #     assert check_return_type(parser)
+    #     return parser
+
     @classmethod
     def build_optimizers(
             cls,
@@ -1172,7 +1868,8 @@
                     args.batch_bins = args.batch_bins * args.ngpu
 
         # filter samples if wav.scp and text are mismatch
-        if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+        if (
+                args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
                 filter_wav_text(args.data_dir, args.train_set)
                 filter_wav_text(args.data_dir, args.dev_set)
@@ -1181,8 +1878,10 @@
 
         if args.train_shape_file is None and args.dataset_type == "small":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
-                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
             if args.simple_ddp:
                 dist.barrier()
             args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1244,9 +1943,9 @@
 
         # 2. Build model
         model = cls.build_model(args=args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model = model.to(
             dtype=getattr(torch, args.train_dtype),
@@ -1374,15 +2073,21 @@
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                 train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
                                                    bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                    mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, 
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
                                                    bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                    mode="eval")
             elif args.dataset_type == "small":
@@ -1929,7 +2634,7 @@
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
-    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+    ) -> Tuple[FunASRModel, argparse.Namespace]:
         """Build model from the files.
 
         This method is used for inference or fine-tuning.
@@ -1956,9 +2661,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         if model_file is not None:

--
Gitblit v1.9.1