From d1c9782515df60bf7578075c649c45b224106cfe Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: Wed, 12 Apr 2023 17:51:21 +0800
Subject: [PATCH] Delete abs_task.py

---
 funasr/tasks/abs_task.py | 1958 -----------------------------------------------------------
 1 file changed, 1958 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
deleted file mode 100644
index d4b8a72..0000000
--- a/funasr/tasks/abs_task.py
+++ /dev/null
@@ -1,1958 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Abstract task module."""
-import argparse
-import functools
-import logging
-import os
-import sys
-from abc import ABC
-from abc import abstractmethod
-from dataclasses import dataclass
-from distutils.version import LooseVersion
-from io import BytesIO
-from pathlib import Path
-from typing import Any
-from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import humanfriendly
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.multiprocessing
-import torch.nn
-import torch.optim
-import yaml
-from torch.utils.data import DataLoader
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr import __version__
-from funasr.datasets.dataset import AbsDataset
-from funasr.datasets.dataset import DATA_TYPES
-from funasr.datasets.dataset import ESPnetDataset
-from funasr.datasets.iterable_dataset import IterableESPnetDataset
-from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.iterators.chunk_iter_factory import ChunkIterFactory
-from funasr.iterators.multiple_iter_factory import MultipleIterFactory
-from funasr.iterators.sequence_iter_factory import SequenceIterFactory
-from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
-from funasr.optimizers.fairseq_adam import FairseqAdam
-from funasr.samplers.build_batch_sampler import BATCH_TYPES
-from funasr.samplers.build_batch_sampler import build_batch_sampler
-from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
-from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
-from funasr.schedulers.tri_stage_scheduler import TriStageLR
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.distributed_utils import DistributedOption
-from funasr.train.trainer import Trainer
-from funasr.utils import config_argparse
-from funasr.utils.build_dataclass import build_dataclass
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import humanfriendly_parse_size_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_int
-from funasr.utils.types import str_or_none
-from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
-from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
-
-try:
-    import wandb
-except Exception:
-    wandb = None
-
-
-optim_classes = dict(
-    adam=torch.optim.Adam,
-    fairseq_adam=FairseqAdam,
-    adamw=torch.optim.AdamW,
-    sgd=SGD,
-    adadelta=torch.optim.Adadelta,
-    adagrad=torch.optim.Adagrad,
-    adamax=torch.optim.Adamax,
-    asgd=torch.optim.ASGD,
-    lbfgs=torch.optim.LBFGS,
-    rmsprop=torch.optim.RMSprop,
-    rprop=torch.optim.Rprop,
-)
-if LooseVersion(torch.__version__) >= LooseVersion("1.10.0"):
-    # From 1.10.0, RAdam is officially supported
-    optim_classes.update(
-        radam=torch.optim.RAdam,
-    )
-try:
-    import torch_optimizer
-
-    optim_classes.update(
-        accagd=torch_optimizer.AccSGD,
-        adabound=torch_optimizer.AdaBound,
-        adamod=torch_optimizer.AdaMod,
-        diffgrad=torch_optimizer.DiffGrad,
-        lamb=torch_optimizer.Lamb,
-        novograd=torch_optimizer.NovoGrad,
-        pid=torch_optimizer.PID,
-        # torch_optimizer<=0.0.1a10 doesn't support qhadam
-        # qhadam=torch_optimizer.QHAdam,
-        qhm=torch_optimizer.QHM,
-        sgdw=torch_optimizer.SGDW,
-        yogi=torch_optimizer.Yogi,
-    )
-    if LooseVersion(torch_optimizer.__version__) < LooseVersion("0.2.0"):
-        # From 0.2.0, RAdam is dropped
-        optim_classes.update(
-            radam=torch_optimizer.RAdam,
-        )
-    del torch_optimizer
-except ImportError:
-    pass
-try:
-    import apex
-
-    optim_classes.update(
-        fusedadam=apex.optimizers.FusedAdam,
-        fusedlamb=apex.optimizers.FusedLAMB,
-        fusednovograd=apex.optimizers.FusedNovoGrad,
-        fusedsgd=apex.optimizers.FusedSGD,
-    )
-    del apex
-except ImportError:
-    pass
-try:
-    import fairscale
-except ImportError:
-    fairscale = None
-
-scheduler_classes = dict(
-    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
-    lambdalr=torch.optim.lr_scheduler.LambdaLR,
-    steplr=torch.optim.lr_scheduler.StepLR,
-    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
-    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
-    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
-    noamlr=NoamLR,
-    warmuplr=WarmupLR,
-    tri_stage=TriStageLR,
-    cycliclr=torch.optim.lr_scheduler.CyclicLR,
-    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
-    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
-)
-# Lower-case the registry keys so that lookups are case-insensitive
-optim_classes = {k.lower(): v for k, v in optim_classes.items()}
-scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()}
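-# Example lookups (sketch), valid because the keys were lower-cased above:
-#   optim_classes["adam"]         -> torch.optim.Adam
-#   scheduler_classes["warmuplr"] -> WarmupLR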
-
-
-@dataclass
-class IteratorOptions:
-    preprocess_fn: callable
-    collate_fn: callable
-    data_path_and_name_and_type: list
-    shape_files: list
-    batch_size: int
-    batch_bins: int
-    batch_type: str
-    max_cache_size: float
-    max_cache_fd: int
-    distributed: bool
-    num_batches: Optional[int]
-    num_iters_per_epoch: Optional[int]
-    train: bool
-
-
-class AbsTask(ABC):
-    # Use @staticmethod or @classmethod instead of instance methods
-    # to avoid God classes
-
-    # If you need more than one optimizer, override this value in a subclass
-    num_optimizers: int = 1
-    trainer = Trainer
-    class_choices_list: List[ClassChoices] = []
-    finetune_args: Optional[argparse.Namespace] = None
-
-    def __init__(self):
-        raise RuntimeError("This class can't be instantiated.")
-
-    @classmethod
-    @abstractmethod
-    def add_task_arguments(cls, parser: argparse.ArgumentParser):
-        pass
-
-    @classmethod
-    @abstractmethod
-    def build_collate_fn(
-            cls, args: argparse.Namespace, train: bool
-    ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]:
-        """Return "collate_fn", which is a callable object and given to DataLoader.
-
-        >>> from torch.utils.data import DataLoader
-        >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...)
-
-        In many cases, you can use our common collate_fn.
-        """
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def build_preprocess_fn(
-            cls, args: argparse.Namespace, train: bool
-    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def required_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        """Define the required names by Task
-
-        This function is used by
-        >>> cls.check_task_requirements()
-        If your model is defined as following,
-
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
-        ...     def forward(self, input, output, opt=None):  pass
-
-        then "required_data_names" should be as
-
-        >>> required_data_names = ('input', 'output')
-        """
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def optional_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        """Define the optional names by Task
-
-        This function is used by
-        >>> cls.check_task_requirements()
-        If your model is defined as follows,
-
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
-        ...     def forward(self, input, output, opt=None):  pass
-
-        then "optional_data_names" should be as
-
-        >>> optional_data_names = ('opt',)
-        """
-        raise NotImplementedError
-
-    @classmethod
-    @abstractmethod
-    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
-        raise NotImplementedError
-
-    @classmethod
-    def get_parser(cls) -> config_argparse.ArgumentParser:
-        assert check_argument_types()
-
-        class ArgumentDefaultsRawTextHelpFormatter(
-            argparse.RawTextHelpFormatter,
-            argparse.ArgumentDefaultsHelpFormatter,
-        ):
-            pass
-
-        parser = config_argparse.ArgumentParser(
-            description="base parser",
-            formatter_class=ArgumentDefaultsRawTextHelpFormatter,
-        )
-
-        # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
-        #  I think '-' looks really confusing if it's written in yaml.
-
-        # NOTE(kamo): add_arguments(..., required=True) can't be used
-        #  to provide --print_config mode. Instead, do:
-        # parser.set_defaults(required=["output_dir"])
-
-        group = parser.add_argument_group("Common configuration")
-
-        group.add_argument(
-            "--print_config",
-            action="store_true",
-            help="Print the config file and exit",
-        )
-        group.add_argument(
-            "--log_level",
-            type=lambda x: x.upper(),
-            default="INFO",
-            choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-            help="The verbose level of logging",
-        )
-        group.add_argument(
-            "--dry_run",
-            type=str2bool,
-            default=False,
-            help="Perform process without training",
-        )
-        group.add_argument(
-            "--iterator_type",
-            type=str,
-            choices=["sequence", "chunk", "task", "none"],
-            default="sequence",
-            help="Specify iterator type",
-        )
-
-        group.add_argument("--output_dir", type=str_or_none, default=None)
-        group.add_argument(
-            "--ngpu",
-            type=int,
-            default=0,
-            help="The number of gpus. 0 indicates CPU mode",
-        )
-        group.add_argument("--seed", type=int, default=0, help="Random seed")
-        group.add_argument(
-            "--num_workers",
-            type=int,
-            default=1,
-            help="The number of workers used for DataLoader",
-        )
-        group.add_argument(
-            "--num_att_plot",
-            type=int,
-            default=3,
-            help="The number images to plot the outputs from attention. "
-                 "This option makes sense only when attention-based model. "
-                 "We can also disable the attention plot by setting it 0",
-        )
-
-        group = parser.add_argument_group("distributed training related")
-        group.add_argument(
-            "--dist_backend",
-            default="nccl",
-            type=str,
-            help="distributed backend",
-        )
-        group.add_argument(
-            "--dist_init_method",
-            type=str,
-            default="env://",
-            help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
-                 '"WORLD_SIZE", and "RANK" are referred.',
-        )
-        group.add_argument(
-            "--dist_world_size",
-            default=None,
-            type=int_or_none,
-            help="number of nodes for distributed training",
-        )
-        group.add_argument(
-            "--dist_rank",
-            type=int_or_none,
-            default=None,
-            help="node rank for distributed training",
-        )
-        group.add_argument(
-            # Not starting with "dist_" for compatibility to launch.py
-            "--local_rank",
-            type=int_or_none,
-            default=None,
-            help="local rank for distributed training. This option is used if "
-                 "--multiprocessing_distributed=false",
-        )
-        group.add_argument(
-            "--dist_master_addr",
-            default=None,
-            type=str_or_none,
-            help="The master address for distributed training. "
-                 "This value is used when dist_init_method == 'env://'",
-        )
-        group.add_argument(
-            "--dist_master_port",
-            default=None,
-            type=int_or_none,
-            help="The master port for distributed training"
-                 "This value is used when dist_init_method == 'env://'",
-        )
-        group.add_argument(
-            "--dist_launcher",
-            default=None,
-            type=str_or_none,
-            choices=["slurm", "mpi", None],
-            help="The launcher type for distributed training",
-        )
-        group.add_argument(
-            "--multiprocessing_distributed",
-            default=False,
-            type=str2bool,
-            help="Use multi-processing distributed training to launch "
-                 "N processes per node, which has N GPUs. This is the "
-                 "fastest way to use PyTorch for either single node or "
-                 "multi node data parallel training",
-        )
-        group.add_argument(
-            "--unused_parameters",
-            type=str2bool,
-            default=False,
-            help="Whether to use the find_unused_parameters in "
-                 "torch.nn.parallel.DistributedDataParallel ",
-        )
-        group.add_argument(
-            "--sharded_ddp",
-            default=False,
-            type=str2bool,
-            help="Enable sharded training provided by fairscale",
-        )
-
-        group = parser.add_argument_group("cudnn mode related")
-        group.add_argument(
-            "--cudnn_enabled",
-            type=str2bool,
-            default=torch.backends.cudnn.enabled,
-            help="Enable CUDNN",
-        )
-        group.add_argument(
-            "--cudnn_benchmark",
-            type=str2bool,
-            default=torch.backends.cudnn.benchmark,
-            help="Enable cudnn-benchmark mode",
-        )
-        group.add_argument(
-            "--cudnn_deterministic",
-            type=str2bool,
-            default=True,
-            help="Enable cudnn-deterministic mode",
-        )
-
-        group = parser.add_argument_group("collect stats mode related")
-        group.add_argument(
-            "--collect_stats",
-            type=str2bool,
-            default=False,
-            help='Run in "collect stats" mode',
-        )
-        group.add_argument(
-            "--write_collected_feats",
-            type=str2bool,
-            default=False,
-            help='Write the output features from the model in "collect stats" mode',
-        )
-
-        group = parser.add_argument_group("Trainer related")
-        group.add_argument(
-            "--max_epoch",
-            type=int,
-            default=40,
-            help="The maximum number epoch to train",
-        )
-        group.add_argument(
-            "--max_update",
-            type=int,
-            default=sys.maxsize,
-            help="The maximum number update step to train",
-        )
-        group.add_argument(
-            "--batch_interval",
-            type=int,
-            default=10000,
-            help="The number of batches between saved checkpoints.",
-        )
-        group.add_argument(
-            "--patience",
-            type=int_or_none,
-            default=None,
-            help="Number of epochs to wait without improvement "
-                 "before stopping the training",
-        )
-        group.add_argument(
-            "--val_scheduler_criterion",
-            type=str,
-            nargs=2,
-            default=("valid", "loss"),
-            help="The criterion used for the value given to the lr scheduler. "
-                 'Give a pair referring the phase, "train" or "valid",'
-                 'and the criterion name. The mode specifying "min" or "max" can '
-                 "be changed by --scheduler_conf",
-        )
-        group.add_argument(
-            "--early_stopping_criterion",
-            type=str,
-            nargs=3,
-            default=("valid", "loss", "min"),
-            help="The criterion used for judging of early stopping. "
-                 'Give a pair referring the phase, "train" or "valid",'
-                 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
-        )
-        group.add_argument(
-            "--best_model_criterion",
-            type=str2triple_str,
-            nargs="+",
-            default=[
-                ("train", "loss", "min"),
-                ("valid", "loss", "min"),
-                ("train", "acc", "max"),
-                ("valid", "acc", "max"),
-            ],
-            help="The criterion used for judging of the best model. "
-                 'Give a pair referring the phase, "train" or "valid",'
-                 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
-        )
-        group.add_argument(
-            "--keep_nbest_models",
-            type=int,
-            nargs="+",
-            default=[10],
-            help="Remove previous snapshots excluding the n-best scored epochs",
-        )
-        group.add_argument(
-            "--nbest_averaging_interval",
-            type=int,
-            default=0,
-            help="The epoch interval to apply model averaging and save nbest models",
-        )
-        group.add_argument(
-            "--grad_clip",
-            type=float,
-            default=5.0,
-            help="Gradient norm threshold to clip",
-        )
-        group.add_argument(
-            "--grad_clip_type",
-            type=float,
-            default=2.0,
-            help="The type of the used p-norm for gradient clip. Can be inf",
-        )
-        group.add_argument(
-            "--grad_noise",
-            type=str2bool,
-            default=False,
-            help="The flag to switch to use noise injection to "
-                 "gradients during training",
-        )
-        group.add_argument(
-            "--accum_grad",
-            type=int,
-            default=1,
-            help="The number of gradient accumulation",
-        )
-        group.add_argument(
-            "--no_forward_run",
-            type=str2bool,
-            default=False,
-            help="Just only iterating data loading without "
-                 "model forwarding and training",
-        )
-        group.add_argument(
-            "--resume",
-            type=str2bool,
-            default=False,
-            help="Enable resuming if checkpoint is existing",
-        )
-        group.add_argument(
-            "--train_dtype",
-            default="float32",
-            choices=["float16", "float32", "float64"],
-            help="Data type for training.",
-        )
-        group.add_argument(
-            "--use_amp",
-            type=str2bool,
-            default=False,
-            help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
-        )
-        group.add_argument(
-            "--log_interval",
-            type=int_or_none,
-            default=None,
-            help="Show the logs every the number iterations in each epochs at the "
-                 "training phase. If None is given, it is decided according the number "
-                 "of training samples automatically .",
-        )
-        group.add_argument(
-            "--use_tensorboard",
-            type=str2bool,
-            default=True,
-            help="Enable tensorboard logging",
-        )
-        group.add_argument(
-            "--use_wandb",
-            type=str2bool,
-            default=False,
-            help="Enable wandb logging",
-        )
-        group.add_argument(
-            "--wandb_project",
-            type=str,
-            default=None,
-            help="Specify wandb project",
-        )
-        group.add_argument(
-            "--wandb_id",
-            type=str,
-            default=None,
-            help="Specify wandb id",
-        )
-        group.add_argument(
-            "--wandb_entity",
-            type=str,
-            default=None,
-            help="Specify wandb entity",
-        )
-        group.add_argument(
-            "--wandb_name",
-            type=str,
-            default=None,
-            help="Specify wandb run name",
-        )
-        group.add_argument(
-            "--wandb_model_log_interval",
-            type=int,
-            default=-1,
-            help="Set the model log period",
-        )
-        group.add_argument(
-            "--detect_anomaly",
-            type=str2bool,
-            default=False,
-            help="Set torch.autograd.set_detect_anomaly",
-        )
-
-        group = parser.add_argument_group("Pretraining model related")
-        group.add_argument("--pretrain_path", help="This option is obsoleted")
-        group.add_argument(
-            "--init_param",
-            type=str,
-            default=[],
-            nargs="*",
-            help="Specify the file path used for initialization of parameters. "
-                 "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
-                 "where file_path is the model file path, "
-                 "src_key specifies the key of model states to be used in the model file, "
-                 "dst_key specifies the attribute of the model to be initialized, "
-                 "and exclude_keys excludes keys of model states for the initialization."
-                 "e.g.\n"
-                 "  # Load all parameters"
-                 "  --init_param some/where/model.pb\n"
-                 "  # Load only decoder parameters"
-                 "  --init_param some/where/model.pb:decoder:decoder\n"
-                 "  # Load only decoder parameters excluding decoder.embed"
-                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
-                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
-        )
-        group.add_argument(
-            "--ignore_init_mismatch",
-            type=str2bool,
-            default=False,
-            help="Ignore size mismatch when loading pre-trained model",
-        )
-        group.add_argument(
-            "--freeze_param",
-            type=str,
-            default=[],
-            nargs="*",
-            help="Freeze parameters",
-        )
-
-        group = parser.add_argument_group("BatchSampler related")
-        group.add_argument(
-            "--num_iters_per_epoch",
-            type=int_or_none,
-            default=None,
-            help="Restrict the number of iterations for training per epoch",
-        )
-        group.add_argument(
-            "--batch_size",
-            type=int,
-            default=20,
-            help="The mini-batch size used for training. Used if batch_type='unsorted',"
-                 " 'sorted', or 'folded'.",
-        )
-        group.add_argument(
-            "--valid_batch_size",
-            type=int_or_none,
-            default=None,
-            help="If not given, the value of --batch_size is used",
-        )
-        group.add_argument(
-            "--batch_bins",
-            type=int,
-            default=1000000,
-            help="The number of batch bins. Used if batch_type='length' or 'numel'",
-        )
-        group.add_argument(
-            "--valid_batch_bins",
-            type=int_or_none,
-            default=None,
-            help="If not given, the value of --batch_bins is used",
-        )
-
-        group.add_argument("--train_shape_file", type=str, action="append", default=[])
-        group.add_argument("--valid_shape_file", type=str, action="append", default=[])
-
-        group = parser.add_argument_group("Sequence iterator related")
-        _batch_type_help = ""
-        for key, value in BATCH_TYPES.items():
-            _batch_type_help += f'"{key}":\n{value}\n'
-        group.add_argument(
-            "--batch_type",
-            type=str,
-            default="length",
-            choices=list(BATCH_TYPES),
-            help=_batch_type_help,
-        )
-        group.add_argument(
-            "--valid_batch_type",
-            type=str_or_none,
-            default=None,
-            choices=list(BATCH_TYPES) + [None],
-            help="If not given, the value of --batch_type is used",
-        )
-        group.add_argument(
-            "--speech_length_min",
-            type=int,
-            default=-1,
-            help="speech length min",
-        )
-        group.add_argument(
-            "--speech_length_max",
-            type=int,
-            default=-1,
-            help="speech length max",
-        )
-        group.add_argument("--fold_length", type=int, action="append", default=[])
-        group.add_argument(
-            "--sort_in_batch",
-            type=str,
-            default="descending",
-            choices=["descending", "ascending"],
-            help="Sort the samples in each mini-batches by the sample "
-                 'lengths. To enable this, "shape_file" must have the length information.',
-        )
-        group.add_argument(
-            "--sort_batch",
-            type=str,
-            default="descending",
-            choices=["descending", "ascending"],
-            help="Sort mini-batches by the sample lengths",
-        )
-        group.add_argument(
-            "--multiple_iterator",
-            type=str2bool,
-            default=False,
-            help="Use multiple iterator mode",
-        )
-
-        group = parser.add_argument_group("Chunk iterator related")
-        group.add_argument(
-            "--chunk_length",
-            type=str_or_int,
-            default=500,
-            help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
-                 "If multiple numbers separated by command are given, "
-                 "one of them is selected randomly for each samples. "
-                 "If two numbers are given with '-', it indicates the range of the choices. "
-                 "Note that if the sequence length is shorter than the all chunk_lengths, "
-                 "the sample is discarded. ",
-        )
-        group.add_argument(
-            "--chunk_shift_ratio",
-            type=float,
-            default=0.5,
-            help="Specify the shift width of chunks. If it's less than 1, "
-                 "allows the overlapping and if bigger than 1, there are some gaps "
-                 "between each chunk.",
-        )
-        group.add_argument(
-            "--num_cache_chunks",
-            type=int,
-            default=1024,
-            help="Shuffle in the specified number of chunks and generate mini-batches "
-                 "More larger this value, more randomness can be obtained.",
-        )
-
-        group = parser.add_argument_group("Dataset related")
-        _data_path_and_name_and_type_help = (
-            "Give three words splitted by comma. It's used for the training data. "
-            "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
-            "The first value, some/path/a.scp, indicates the file path, "
-            "and the second, foo, is the key name used for the mini-batch data, "
-            "and the last, sound, decides the file type. "
-            "This option is repeatable, so you can input any number of features "
-            "for your task. Supported file types are as follows:\n\n"
-        )
-        for key, dic in DATA_TYPES.items():
-            _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
-
-        # for large dataset
-        group.add_argument(
-            "--dataset_type",
-            type=str,
-            default="small",
-            help="whether to use dataloader for large dataset",
-        )
-        group.add_argument(
-            "--dataset_conf",
-            action=NestedDictAction,
-            default=dict(),
-            help="The keyword arguments for the dataset",
-        )
-        group.add_argument(
-            "--train_data_file",
-            type=str,
-            default=None,
-            help="train_list for large dataset",
-        )
-        group.add_argument(
-            "--valid_data_file",
-            type=str,
-            default=None,
-            help="valid_list for large dataset",
-        )
-
-        group.add_argument(
-            "--train_data_path_and_name_and_type",
-            type=str2triple_str,
-            action="append",
-            default=[],
-            help=_data_path_and_name_and_type_help,
-        )
-        group.add_argument(
-            "--valid_data_path_and_name_and_type",
-            type=str2triple_str,
-            action="append",
-            default=[],
-        )
-        group.add_argument(
-            "--allow_variable_data_keys",
-            type=str2bool,
-            default=False,
-            help="Allow the arbitrary keys for mini-batch with ignoring "
-                 "the task requirements",
-        )
-        group.add_argument(
-            "--max_cache_size",
-            type=humanfriendly.parse_size,
-            default=0.0,
-            help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
-        )
-        group.add_argument(
-            "--max_cache_fd",
-            type=int,
-            default=32,
-            help="The maximum number of file descriptors to be kept "
-                 "as opened for ark files. "
-                 "This feature is only valid when data type is 'kaldi_ark'.",
-        )
-        group.add_argument(
-            "--valid_max_cache_size",
-            type=humanfriendly_parse_size_or_none,
-            default=None,
-            help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
-                 "If None, the 5 percent size of --max_cache_size",
-        )
-
-        group = parser.add_argument_group("Optimizer related")
-        for i in range(1, cls.num_optimizers + 1):
-            suf = "" if i == 1 else str(i)
-            group.add_argument(
-                f"--optim{suf}",
-                type=lambda x: x.lower(),
-                default="adadelta",
-                choices=list(optim_classes),
-                help="The optimizer type",
-            )
-            group.add_argument(
-                f"--optim{suf}_conf",
-                action=NestedDictAction,
-                default=dict(),
-                help="The keyword arguments for optimizer",
-            )
-            group.add_argument(
-                f"--scheduler{suf}",
-                type=lambda x: str_or_none(x.lower()),
-                default=None,
-                choices=list(scheduler_classes) + [None],
-                help="The lr scheduler type",
-            )
-            group.add_argument(
-                f"--scheduler{suf}_conf",
-                action=NestedDictAction,
-                default=dict(),
-                help="The keyword arguments for lr scheduler",
-            )
-
-        # for training on PAI
-        group = parser.add_argument_group("PAI training related")
-        group.add_argument(
-            "--use_pai",
-            type=str2bool,
-            default=False,
-            help="flag to indicate whether training on PAI",
-        )
-        group.add_argument(
-            "--simple_ddp",
-            type=str2bool,
-            default=False,
-        )
-        group.add_argument(
-            "--num_worker_count",
-            type=int,
-            default=1,
-            help="The number of machines on PAI.",
-        )
-        group.add_argument(
-            "--access_key_id",
-            type=str,
-            default=None,
-            help="The username for oss.",
-        )
-        group.add_argument(
-            "--access_key_secret",
-            type=str,
-            default=None,
-            help="The password for oss.",
-        )
-        group.add_argument(
-            "--endpoint",
-            type=str,
-            default=None,
-            help="The endpoint for oss.",
-        )
-        group.add_argument(
-            "--bucket_name",
-            type=str,
-            default=None,
-            help="The bucket name for oss.",
-        )
-        group.add_argument(
-            "--oss_bucket",
-            default=None,
-            help="oss bucket.",
-        )
-
-        cls.trainer.add_arguments(parser)
-        cls.add_task_arguments(parser)
-
-        assert check_return_type(parser)
-        return parser
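-    # Usage sketch ("SomeTask" is a hypothetical AbsTask subclass):
-    #   parser = SomeTask.get_parser()
-    #   args = parser.parse_args(["--output_dir", "exp/asr", "--ngpu", "1"])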
-
-    @classmethod
-    def build_optimizers(
-            cls,
-            args: argparse.Namespace,
-            model: torch.nn.Module,
-    ) -> List[torch.optim.Optimizer]:
-        if cls.num_optimizers != 1:
-            raise RuntimeError(
-                "build_optimizers() must be overridden if num_optimizers != 1"
-            )
-
-        optim_class = optim_classes.get(args.optim)
-        if optim_class is None:
-            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
-        if args.sharded_ddp:
-            if fairscale is None:
-                raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'")
-            optim = fairscale.optim.oss.OSS(
-                params=model.parameters(), optim=optim_class, **args.optim_conf
-            )
-        else:
-            optim = optim_class(model.parameters(), **args.optim_conf)
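-        # Example (sketch): with --optim adam and an optim_conf of {"lr": 0.001},
-        # optim is torch.optim.Adam(model.parameters(), lr=0.001).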
-
-        optimizers = [optim]
-        return optimizers
-
-    @classmethod
-    def exclude_opts(cls) -> Tuple[str, ...]:
-        """The options not to be shown by --print_config"""
-        return "required", "print_config", "config", "ngpu"
-
-    @classmethod
-    def get_default_config(cls) -> Dict[str, Any]:
-        """Return the configuration as dict.
-
-        This method is used by print_config()
-        """
-
-        def get_class_type(name: str, classes: dict):
-            _cls = classes.get(name)
-            if _cls is None:
-                raise ValueError(f"must be one of {list(classes)}: {name}")
-            return _cls
-
-        # This method is used only for --print_config
-        assert check_argument_types()
-        parser = cls.get_parser()
-        args, _ = parser.parse_known_args()
-        config = vars(args)
-        # Excludes the options not to be shown
-        for k in AbsTask.exclude_opts():
-            config.pop(k)
-
-        for i in range(1, cls.num_optimizers + 1):
-            suf = "" if i == 1 else str(i)
-            name = config[f"optim{suf}"]
-            optim_class = get_class_type(name, optim_classes)
-            conf = get_default_kwargs(optim_class)
-            # Overwrite the default by the arguments,
-            conf.update(config[f"optim{suf}_conf"])
-            # and set it again
-            config[f"optim{suf}_conf"] = conf
-
-            name = config[f"scheduler{suf}"]
-            if name is not None:
-                scheduler_class = get_class_type(name, scheduler_classes)
-                conf = get_default_kwargs(scheduler_class)
-                # Overwrite the default by the arguments,
-                conf.update(config[f"scheduler{suf}_conf"])
-                # and set it again
-                config[f"scheduler{suf}_conf"] = conf
-
-        for class_choices in cls.class_choices_list:
-            if getattr(args, class_choices.name) is not None:
-                class_obj = class_choices.get_class(getattr(args, class_choices.name))
-                conf = get_default_kwargs(class_obj)
-                name = class_choices.name
-                # Overwrite the default by the arguments,
-                conf.update(config[f"{name}_conf"])
-                # and set it again
-                config[f"{name}_conf"] = conf
-        return config
-
-    @classmethod
-    def check_required_command_args(cls, args: argparse.Namespace):
-        assert check_argument_types()
-        if hasattr(args, "required"):
-            for k in vars(args):
-                if "-" in k:
-                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
-
-            required = ", ".join(
-                f"--{a}" for a in args.required if getattr(args, a) is None
-            )
-
-            if len(required) != 0:
-                parser = cls.get_parser()
-                parser.print_help(file=sys.stderr)
-                p = Path(sys.argv[0]).name
-                print(file=sys.stderr)
-                print(
-                    f"{p}: error: the following arguments are required: " f"{required}",
-                    file=sys.stderr,
-                )
-                sys.exit(2)
-
-    @classmethod
-    def check_task_requirements(
-            cls,
-            dataset: Union[AbsDataset, IterableESPnetDataset],
-            allow_variable_data_keys: bool,
-            train: bool,
-            inference: bool = False,
-    ) -> None:
-        """Check if the dataset satisfy the requirement of current Task"""
-        assert check_argument_types()
-        mes = (
-            f"If you intend to use an additional input, modify "
-            f'"{cls.__name__}.required_data_names()" or '
-            f'"{cls.__name__}.optional_data_names()". '
-            f"Otherwise you need to set --allow_variable_data_keys true "
-        )
-
-        for k in cls.required_data_names(train, inference):
-            if not dataset.has_name(k):
-                raise RuntimeError(
-                    f'"{cls.required_data_names(train, inference)}" are required for'
-                    f' {cls.__name__}, but "{dataset.names()}" are input.\n{mes}'
-                )
-        if not allow_variable_data_keys:
-            task_keys = cls.required_data_names(
-                train, inference
-            ) + cls.optional_data_names(train, inference)
-            for k in dataset.names():
-                if k not in task_keys:
-                    raise RuntimeError(
-                        f"The data-name must be one of {task_keys} "
-                        f'for {cls.__name__}: "{k}" is not allowed.\n{mes}'
-                    )
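-        # e.g. (hypothetical names): a task requiring ("speech", "text") rejects a
-        # dataset that also supplies "pitch" unless --allow_variable_data_keys true.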
-
-    @classmethod
-    def print_config(cls, file=sys.stdout) -> None:
-        assert check_argument_types()
-        # Shows the config: e.g. python train.py asr --print_config
-        config = cls.get_default_config()
-        file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
-
-    @classmethod
-    def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
-        assert check_argument_types()
-        print(get_commandline_args(), file=sys.stderr)
-        if args is None:
-            parser = cls.get_parser()
-            args = parser.parse_args(cmd)
-        args.version = __version__
-        if args.pretrain_path is not None:
-            raise RuntimeError("--pretrain_path is deprecated. Use --init_param")
-        if args.print_config:
-            cls.print_config()
-            sys.exit(0)
-        cls.check_required_command_args(args)
-
-        if args.distributed and args.multiprocessing_distributed:
-            assert args.ngpu > 1
-        cls.main_worker(args)
-
-    @classmethod
-    def run(cls):
-        assert hasattr(cls, "finetune_args")
-        args = cls.finetune_args
-        args.train_shape_file = None
-        if args.distributed:
-            args.simple_ddp = True
-        else:
-            args.simple_ddp = False
-            args.ngpu = 1
-        args.use_pai = False
-        args.batch_type = "length"
-        args.oss_bucket = None
-        args.input_size = None
-        cls.main_worker(args)
-
-    @classmethod
-    def main_worker(cls, args: argparse.Namespace):
-        assert check_argument_types()
-
-        # 0. Init distributed process
-        distributed_option = build_dataclass(DistributedOption, args)
-        # Setting distributed_option.dist_rank, etc.
-        if args.use_pai:
-            distributed_option.init_options_pai()
-        elif not args.simple_ddp:
-            distributed_option.init_options()
-
-        # Invoking torch.distributed.init_process_group
-        if args.use_pai:
-            distributed_option.init_torch_distributed_pai(args)
-        elif not args.simple_ddp:
-            distributed_option.init_torch_distributed(args)
-        elif args.distributed and args.simple_ddp:
-            distributed_option.init_torch_distributed_pai(args)
-            args.ngpu = dist.get_world_size()
-            if args.dataset_type == "small":
-                if args.batch_size is not None:
-                    args.batch_size = args.batch_size * args.ngpu
-                if args.batch_bins is not None:
-                    args.batch_bins = args.batch_bins * args.ngpu
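-                # NOTE: the global batch settings are scaled by the world size here;
-                # each rank then presumably consumes a 1/ngpu share via its sampler.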
-
-        # filter samples if wav.scp and text are mismatch
-        if (args.train_shape_file is None and args.dataset_type == "small") or (args.train_data_file is None and args.dataset_type == "large"):
-            if not args.simple_ddp or distributed_option.dist_rank == 0:
-                filter_wav_text(args.data_dir, args.train_set)
-                filter_wav_text(args.data_dir, args.dev_set)
-            if args.simple_ddp:
-                dist.barrier()
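-        # Rank 0 prepares the data files above while the barrier keeps the other
-        # ranks waiting; the same prepare-then-barrier pattern repeats below.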
-
-        if args.train_shape_file is None and args.dataset_type == "small":
-            if not args.simple_ddp or distributed_option.dist_rank == 0:
-                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-            if args.simple_ddp:
-                dist.barrier()
-            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
-            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
-
-        if args.train_data_file is None and args.dataset_type == "large":
-            if not args.simple_ddp or distributed_option.dist_rank == 0:
-                generate_data_list(args.data_dir, args.train_set)
-                generate_data_list(args.data_dir, args.dev_set)
-            if args.simple_ddp:
-                dist.barrier()
-            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
-            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
-
-        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
-        if not distributed_option.distributed or distributed_option.dist_rank == 0:
-            if not distributed_option.distributed:
-                _rank = ""
-            else:
-                _rank = (
-                    f":{distributed_option.dist_rank}/"
-                    f"{distributed_option.dist_world_size}"
-                )
-
-            # NOTE(kamo):
-            # logging.basicConfig() is invoked in main_worker() instead of main()
-            # because it can be invoked only once in a process.
-            # FIXME(kamo): Should we use logging.getLogger()?
-            # BUGFIX: Remove previous handlers and reset log level
-            for handler in logging.root.handlers[:]:
-                logging.root.removeHandler(handler)
-            logging.basicConfig(
-                level=args.log_level,
-                format=f"[{os.uname()[1].split('.')[0]}]"
-                       f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-            )
-        else:
-            # BUGFIX: Remove previous handlers and reset log level
-            for handler in logging.root.handlers[:]:
-                logging.root.removeHandler(handler)
-            # Suppress logging if RANK != 0
-            logging.basicConfig(
-                level="ERROR",
-                format=f"[{os.uname()[1].split('.')[0]}]"
-                       f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-            )
-        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                       distributed_option.dist_rank,
-                                                                       distributed_option.local_rank))
-
-        # 1. Set random-seed
-        set_all_random_seed(args.seed)
-        torch.backends.cudnn.enabled = args.cudnn_enabled
-        torch.backends.cudnn.benchmark = args.cudnn_benchmark
-        torch.backends.cudnn.deterministic = args.cudnn_deterministic
-        if args.detect_anomaly:
-            logging.info("Invoking torch.autograd.set_detect_anomaly(True)")
-            torch.autograd.set_detect_anomaly(args.detect_anomaly)
-
-        # 2. Build model
-        model = cls.build_model(args=args)
-        if not isinstance(model, AbsESPnetModel):
-            raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
-            )
-        model = model.to(
-            dtype=getattr(torch, args.train_dtype),
-            device="cuda" if args.ngpu > 0 else "cpu",
-        )
-        for t in args.freeze_param:
-            for k, p in model.named_parameters():
-                if k.startswith(t + ".") or k == t:
-                    logging.info(f"Setting {k}.requires_grad = False")
-                    p.requires_grad = False
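-        # e.g. --freeze_param encoder freezes the parameter named "encoder" itself
-        # and every parameter under "encoder.*".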
-
-        # 3. Build optimizer
-        optimizers = cls.build_optimizers(args, model=model)
-
-        # 4. Build schedulers
-        schedulers = []
-        for i, optim in enumerate(optimizers, 1):
-            suf = "" if i == 1 else str(i)
-            name = getattr(args, f"scheduler{suf}")
-            conf = getattr(args, f"scheduler{suf}_conf")
-            if name is not None:
-                cls_ = scheduler_classes.get(name)
-                if cls_ is None:
-                    raise ValueError(
-                        f"must be one of {list(scheduler_classes)}: {name}"
-                    )
-                scheduler = cls_(optim, **conf)
-            else:
-                scheduler = None
-
-            schedulers.append(scheduler)
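-        # "schedulers" is index-aligned with "optimizers"; an entry is None when
-        # no scheduler was configured for that optimizer.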
-
-        logging.info(pytorch_cudnn_version())
-        logging.info(model_summary(model))
-        for i, (o, s) in enumerate(zip(optimizers, schedulers), 1):
-            suf = "" if i == 1 else str(i)
-            logging.info(f"Optimizer{suf}:\n{o}")
-            logging.info(f"Scheduler{suf}: {s}")
-
-        # 5. Dump "args" to config.yaml
-        # NOTE(kamo): "args" should be saved after object-buildings are done
-        #  because they are allowed to modify "args".
-        output_dir = Path(args.output_dir)
-        if not distributed_option.distributed or distributed_option.dist_rank == 0:
-            output_dir.mkdir(parents=True, exist_ok=True)
-            with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
-                logging.info(
-                    f'Saving the configuration in {output_dir / "config.yaml"}'
-                )
-                if args.use_pai:
-                    buffer = BytesIO()
-                    torch.save({"config": vars(args)}, buffer)
-                    args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
-                else:
-                    yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
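-        # Under PAI the configuration is serialized to the OSS bucket instead of
-        # the local config.yaml.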
-
-        if args.dry_run:
-            pass
-        elif args.collect_stats:
-            # Perform on collect_stats mode. This mode has two roles
-            # - Derive the length and dimension of all input data
-            # - Accumulate feats, square values, and the length for whitening
-
-            if args.valid_batch_size is None:
-                args.valid_batch_size = args.batch_size
-
-            if len(args.train_shape_file) != 0:
-                train_key_file = args.train_shape_file[0]
-            else:
-                train_key_file = None
-            if len(args.valid_shape_file) != 0:
-                valid_key_file = args.valid_shape_file[0]
-            else:
-                valid_key_file = None
-
-            collect_stats(
-                model=model,
-                train_iter=cls.build_streaming_iterator(
-                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
-                    key_file=train_key_file,
-                    batch_size=args.batch_size,
-                    dtype=args.train_dtype,
-                    num_workers=args.num_workers,
-                    allow_variable_data_keys=args.allow_variable_data_keys,
-                    ngpu=args.ngpu,
-                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
-                    collate_fn=cls.build_collate_fn(args, train=False),
-                ),
-                valid_iter=cls.build_streaming_iterator(
-                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
-                    key_file=valid_key_file,
-                    batch_size=args.valid_batch_size,
-                    dtype=args.train_dtype,
-                    num_workers=args.num_workers,
-                    allow_variable_data_keys=args.allow_variable_data_keys,
-                    ngpu=args.ngpu,
-                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
-                    collate_fn=cls.build_collate_fn(args, train=False),
-                ),
-                output_dir=output_dir,
-                ngpu=args.ngpu,
-                log_interval=args.log_interval,
-                write_collected_feats=args.write_collected_feats,
-            )
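-            # The shape files and feature statistics collected here are written
-            # under output_dir; a later training run typically points
-            # --train_shape_file at them and uses the stats for normalization (sketch).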
-        else:
-            logging.info("Training args: {}".format(args))
-            # 6. Loads pre-trained model
-            for p in args.init_param:
-                logging.info(f"Loading pretrained params from {p}")
-                load_pretrained_model(
-                    model=model,
-                    init_param=p,
-                    ignore_init_mismatch=args.ignore_init_mismatch,
-                    # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
-                    #   in PyTorch<=1.4
-                    map_location=f"cuda:{torch.cuda.current_device()}"
-                    if args.ngpu > 0
-                    else "cpu",
-                    oss_bucket=args.oss_bucket,
-                )
-
-            # 7. Build iterator factories
-            if args.dataset_type == "large":
-                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="eval")
-            elif args.dataset_type == "small":
-                train_iter_factory = cls.build_iter_factory(
-                    args=args,
-                    distributed_option=distributed_option,
-                    mode="train",
-                )
-                valid_iter_factory = cls.build_iter_factory(
-                    args=args,
-                    distributed_option=distributed_option,
-                    mode="valid",
-                )
-            else:
-                raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-
-            if args.scheduler == "tri_stage":
-                for scheduler in schedulers:
-                    scheduler.init_tri_stage_scheudler(max_update=args.max_update)
-
-            # 8. Start training
-            if args.use_wandb:
-                if wandb is None:
-                    raise RuntimeError("Please install wandb")
-
-                try:
-                    wandb.login()
-                except wandb.errors.UsageError:
-                    logging.info("wandb not configured! run `wandb login` to enable")
-                    args.use_wandb = False
-
-            if args.use_wandb:
-                if (
-                        not distributed_option.distributed
-                        or distributed_option.dist_rank == 0
-                ):
-                    if args.wandb_project is None:
-                        project = "FunASR_" + cls.__name__
-                    else:
-                        project = args.wandb_project
-
-                    if args.wandb_name is None:
-                        name = str(Path(".").resolve()).replace("/", "_")
-                    else:
-                        name = args.wandb_name
-
-                    wandb.init(
-                        entity=args.wandb_entity,
-                        project=project,
-                        name=name,
-                        dir=output_dir,
-                        id=args.wandb_id,
-                        resume="allow",
-                    )
-                    wandb.config.update(args)
-                else:
-                    # wandb also supports grouping for distributed training,
-                    # but we only log aggregated data,
-                    # so it's enough to log from the rank-0 node.
-                    args.use_wandb = False
-
-            # Don't give args to trainer.run() directly!
-            # Instead, define an "Options" object and build it here.
-            trainer_options = cls.trainer.build_options(args)
-            cls.trainer.run(
-                model=model,
-                optimizers=optimizers,
-                schedulers=schedulers,
-                train_iter_factory=train_iter_factory,
-                valid_iter_factory=valid_iter_factory,
-                trainer_options=trainer_options,
-                distributed_option=distributed_option,
-            )
-
-            if args.use_wandb and wandb.run:
-                wandb.finish()
-
-    @classmethod
-    def build_iter_options(
-            cls,
-            args: argparse.Namespace,
-            distributed_option: DistributedOption,
-            mode: str,
-    ):
-        if mode == "train":
-            preprocess_fn = cls.build_preprocess_fn(args, train=True)
-            collate_fn = cls.build_collate_fn(args, train=True)
-            data_path_and_name_and_type = args.train_data_path_and_name_and_type
-            shape_files = args.train_shape_file
-            batch_size = args.batch_size
-            batch_bins = args.batch_bins
-            batch_type = args.batch_type
-            max_cache_size = args.max_cache_size
-            max_cache_fd = args.max_cache_fd
-            distributed = distributed_option.distributed
-            num_batches = None
-            num_iters_per_epoch = args.num_iters_per_epoch
-            train = True
-
-        elif mode == "valid":
-            preprocess_fn = cls.build_preprocess_fn(args, train=False)
-            collate_fn = cls.build_collate_fn(args, train=False)
-            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
-            shape_files = args.valid_shape_file
-
-            if args.valid_batch_type is None:
-                batch_type = args.batch_type
-            else:
-                batch_type = args.valid_batch_type
-            if args.valid_batch_size is None:
-                batch_size = args.batch_size
-            else:
-                batch_size = args.valid_batch_size
-            if args.valid_batch_bins is None:
-                batch_bins = args.batch_bins
-            else:
-                batch_bins = args.valid_batch_bins
-            if args.valid_max_cache_size is None:
-                # Cache 5% of maximum size for validation loader
-                max_cache_size = 0.05 * args.max_cache_size
-            else:
-                max_cache_size = args.valid_max_cache_size
-            max_cache_fd = args.max_cache_fd
-            distributed = distributed_option.distributed
-            num_batches = None
-            num_iters_per_epoch = None
-            train = False
-        else:
-            raise NotImplementedError(f"mode={mode}")
-
-        return IteratorOptions(
-            preprocess_fn=preprocess_fn,
-            collate_fn=collate_fn,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            shape_files=shape_files,
-            batch_type=batch_type,
-            batch_size=batch_size,
-            batch_bins=batch_bins,
-            num_batches=num_batches,
-            max_cache_size=max_cache_size,
-            max_cache_fd=max_cache_fd,
-            distributed=distributed,
-            num_iters_per_epoch=num_iters_per_epoch,
-            train=train,
-        )
-
-    @classmethod
-    def build_iter_factory(
-            cls,
-            args: argparse.Namespace,
-            distributed_option: DistributedOption,
-            mode: str,
-            kwargs: dict = None,
-    ) -> AbsIterFactory:
-        """Build a factory object of mini-batch iterator.
-
-        This object is invoked at every epochs to build the iterator for each epoch
-        as following:
-
-        >>> iter_factory = cls.build_iter_factory(...)
-        >>> for epoch in range(1, max_epoch):
-        ...     for keys, batch in iter_fatory.build_iter(epoch):
-        ...         model(**batch)
-
-        The mini-batches for each epochs are fully controlled by this class.
-        Note that the random seed used for shuffling is decided as "seed + epoch" and
-        the generated mini-batches can be reproduces when resuming.
-
-        Note that the definition of "epoch" doesn't always indicate
-        to run out of the whole training corpus.
-        "--num_iters_per_epoch" option restricts the number of iterations for each epoch
-        and the rest of samples for the originally epoch are left for the next epoch.
-        e.g. if the number of mini-batches equals 4, the following two are
-        equivalent:
-
-        - 1 epoch without "--num_iters_per_epoch"
-        - 4 epochs with "--num_iters_per_epoch" == 1
-
-        """
-        assert check_argument_types()
-        iter_options = cls.build_iter_options(args, distributed_option, mode)
-
-        # Overwrite iter_options if any kwargs is given
-        if kwargs is not None:
-            for k, v in kwargs.items():
-                setattr(iter_options, k, v)
-
-        if args.iterator_type == "sequence":
-            return cls.build_sequence_iter_factory(
-                args=args,
-                iter_options=iter_options,
-                mode=mode,
-            )
-        elif args.iterator_type == "chunk":
-            return cls.build_chunk_iter_factory(
-                args=args,
-                iter_options=iter_options,
-                mode=mode,
-            )
-        elif args.iterator_type == "task":
-            return cls.build_task_iter_factory(
-                args=args,
-                iter_options=iter_options,
-                mode=mode,
-            )
-        else:
-            raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}")
-
-    @classmethod
-    def build_sequence_iter_factory(
-            cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
-    ) -> AbsIterFactory:
-        assert check_argument_types()
-
-        if args.frontend_conf is not None and "fs" in args.frontend_conf:
-            dest_sample_rate = args.frontend_conf["fs"]
-        else:
-            dest_sample_rate = 16000
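-        # Default to 16 kHz when the frontend config does not specify "fs".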
-
-        dataset = ESPnetDataset(
-            iter_options.data_path_and_name_and_type,
-            float_dtype=args.train_dtype,
-            preprocess=iter_options.preprocess_fn,
-            max_cache_size=iter_options.max_cache_size,
-            max_cache_fd=iter_options.max_cache_fd,
-<<<<<<< HEAD
-            dest_sample_rate=args.frontend_conf["fs"] if args.frontend_conf else 16000,
-=======
-            dest_sample_rate=dest_sample_rate,
->>>>>>> main
-        )
-        cls.check_task_requirements(
-            dataset, args.allow_variable_data_keys, train=iter_options.train
-        )
-
-        if Path(
-                Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category"
-        ).exists():
-            utt2category_file = str(
-                Path(
-                    Path(iter_options.data_path_and_name_and_type[0][0]).parent,
-                    "utt2category",
-                )
-            )
-        else:
-            utt2category_file = None
-        batch_sampler = build_batch_sampler(
-            type=iter_options.batch_type,
-            shape_files=iter_options.shape_files,
-            fold_lengths=args.fold_length,
-            batch_size=iter_options.batch_size,
-            batch_bins=iter_options.batch_bins,
-            sort_in_batch=args.sort_in_batch,
-            sort_batch=args.sort_batch,
-            drop_last=False,
-            min_batch_size=torch.distributed.get_world_size()
-            if iter_options.distributed
-            else 1,
-            utt2category_file=utt2category_file,
-        )
-
-        batches = list(batch_sampler)
-        if iter_options.num_batches is not None:
-            batches = batches[: iter_options.num_batches]
-
-        bs_list = [len(batch) for batch in batches]
-
-        logging.info(f"[{mode}] dataset:\n{dataset}")
-        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
-        logging.info(
-            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
-            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
-        )
-
-        if args.scheduler == "tri_stage" and mode == "train":
-            args.max_update = len(bs_list) * args.max_epoch
-            logging.info("Max update: {}".format(args.max_update))
-
-        if iter_options.distributed:
-            world_size = torch.distributed.get_world_size()
-            rank = torch.distributed.get_rank()
-            for batch in batches:
-                if len(batch) < world_size:
-                    raise RuntimeError(
-                        f"The batch-size must be equal or more than world_size: "
-                        f"{len(batch)} < {world_size}"
-                    )
-            batches = [batch[rank::world_size] for batch in batches]
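-            # Illustration: with world_size=2, a batch [u0, u1, u2, u3] is
-            # sliced into [u0, u2] on rank 0 and [u1, u3] on rank 1; each
-            # rank keeps every world_size-th sample of the mini-batch.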
-
-        return SequenceIterFactory(
-            dataset=dataset,
-            batches=batches,
-            seed=args.seed,
-            num_iters_per_epoch=iter_options.num_iters_per_epoch,
-            shuffle=iter_options.train,
-            num_workers=args.num_workers,
-            collate_fn=iter_options.collate_fn,
-            pin_memory=args.ngpu > 0,
-        )
-
-    @classmethod
-    def build_chunk_iter_factory(
-            cls,
-            args: argparse.Namespace,
-            iter_options: IteratorOptions,
-            mode: str,
-    ) -> AbsIterFactory:
-        assert check_argument_types()
-
-        dataset = ESPnetDataset(
-            iter_options.data_path_and_name_and_type,
-            float_dtype=args.train_dtype,
-            preprocess=iter_options.preprocess_fn,
-            max_cache_size=iter_options.max_cache_size,
-            max_cache_fd=iter_options.max_cache_fd,
-        )
-        cls.check_task_requirements(
-            dataset, args.allow_variable_data_keys, train=iter_options.train
-        )
-
-        if len(iter_options.shape_files) == 0:
-            key_file = iter_options.data_path_and_name_and_type[0][0]
-        else:
-            key_file = iter_options.shape_files[0]
-
-        batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file)
-        batches = list(batch_sampler)
-        if iter_options.num_batches is not None:
-            batches = batches[: iter_options.num_batches]
-        logging.info(f"[{mode}] dataset:\n{dataset}")
-
-        if iter_options.distributed:
-            world_size = torch.distributed.get_world_size()
-            rank = torch.distributed.get_rank()
-            if len(batches) < world_size:
-                raise RuntimeError("Number of samples is smaller than world_size")
-            if iter_options.batch_size < world_size:
-                raise RuntimeError("batch_size must be equal or more than world_size")
-
-            if rank < iter_options.batch_size % world_size:
-                batch_size = iter_options.batch_size // world_size + 1
-            else:
-                batch_size = iter_options.batch_size // world_size
-            num_cache_chunks = args.num_cache_chunks // world_size
-            # NOTE(kamo): The whole corpus is split by sample count without
-            #   considering each sample's length, so the number of iterations
-            #   is not always equal across ranks and training is limited by
-            #   the rank with the fewest iterations,
-            #   i.e. the samples beyond that count are discarded.
-            batches = batches[rank::world_size]
-        else:
-            batch_size = iter_options.batch_size
-            num_cache_chunks = args.num_cache_chunks
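-        # Worked example for the distributed branch above: with
-        # iter_options.batch_size=10 and world_size=4, batch_size % world_size
-        # is 2, so ranks 0 and 1 use 10 // 4 + 1 = 3 samples per mini-batch,
-        # ranks 2 and 3 use 10 // 4 = 2, and 3 + 3 + 2 + 2 == 10.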
-
-        return ChunkIterFactory(
-            dataset=dataset,
-            batches=batches,
-            seed=args.seed,
-            batch_size=batch_size,
-            # For chunk iterator,
-            # --num_iters_per_epoch doesn't indicate the number of iterations,
-            # but indicates the number of samples.
-            num_samples_per_epoch=iter_options.num_iters_per_epoch,
-            shuffle=iter_options.train,
-            num_workers=args.num_workers,
-            collate_fn=iter_options.collate_fn,
-            pin_memory=args.ngpu > 0,
-            chunk_length=args.chunk_length,
-            chunk_shift_ratio=args.chunk_shift_ratio,
-            num_cache_chunks=num_cache_chunks,
-        )
-
-    # NOTE(kamo): Not abstract class
-    @classmethod
-    def build_task_iter_factory(
-            cls,
-            args: argparse.Namespace,
-            iter_options: IteratorOptions,
-            mode: str,
-    ) -> AbsIterFactory:
-        """Build task specific iterator factory
-
-        Example:
-
-            >>> class YourTask(AbsTask):
-            ... @classmethod
-            ... def add_task_arguments(cls, parser: argparse.ArgumentParser):
-            ...     parser.set_defaults(iterator_type="task")
-            ...
-            ... @classmethod
-            ... def build_task_iter_factory(
-            ...     cls,
-            ...     args: argparse.Namespace,
-            ...     iter_options: IteratorOptions,
-            ...     mode: str,
-            ... ):
-            ...     return FooIterFactory(...)
-            ...
-            ... @classmethod
-            ... def build_iter_options(
-            ....    args: argparse.Namespace,
-            ...     distributed_option: DistributedOption,
-            ...     mode: str
-            ... ):
-            ...     # if you need to customize options object
-        """
-        raise NotImplementedError
-
-    @classmethod
-    def build_multiple_iter_factory(
-            cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
-    ):
-        assert check_argument_types()
-        iter_options = cls.build_iter_options(args, distributed_option, mode)
-        assert len(iter_options.data_path_and_name_and_type) > 0, len(
-            iter_options.data_path_and_name_and_type
-        )
-
-        # 1. Sanity check
-        num_splits = None
-        for path in [
-                        path for path, _, _ in iter_options.data_path_and_name_and_type
-                    ] + list(iter_options.shape_files):
-            if not Path(path).is_dir():
-                raise RuntimeError(f"{path} is not a directory")
-            p = Path(path) / "num_splits"
-            if not p.exists():
-                raise FileNotFoundError(f"{p} is not found")
-            with p.open() as f:
-                _num_splits = int(f.read())
-                if num_splits is not None and num_splits != _num_splits:
-                    raise RuntimeError(
-                        f"Number of splits are mismathed: "
-                        f"{iter_options.data_path_and_name_and_type[0][0]} and {path}"
-                    )
-                num_splits = _num_splits
-
-            for i in range(num_splits):
-                p = Path(path) / f"split.{i}"
-                if not p.exists():
-                    raise FileNotFoundError(f"{p} is not found")
-
-        # 2. Create functions to build an iter factory for each split
-        data_path_and_name_and_type_list = [
-            [
-                (str(Path(p) / f"split.{i}"), n, t)
-                for p, n, t in iter_options.data_path_and_name_and_type
-            ]
-            for i in range(num_splits)
-        ]
-        shape_files_list = [
-            [str(Path(s) / f"split.{i}") for s in iter_options.shape_files]
-            for i in range(num_splits)
-        ]
-        num_iters_per_epoch_list = [
-            (iter_options.num_iters_per_epoch + i) // num_splits
-            if iter_options.num_iters_per_epoch is not None
-            else None
-            for i in range(num_splits)
-        ]
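-        # e.g. num_iters_per_epoch=10 with num_splits=3 yields
-        # [(10 + 0) // 3, (10 + 1) // 3, (10 + 2) // 3] == [3, 3, 4],
-        # which sums back to the original 10 iterations.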
-        max_cache_size = iter_options.max_cache_size / num_splits
-
-        # Note that iter-factories are built for each epoch at runtime lazily.
-        build_funcs = [
-            functools.partial(
-                cls.build_iter_factory,
-                args,
-                distributed_option,
-                mode,
-                kwargs=dict(
-                    data_path_and_name_and_type=_data_path_and_name_and_type,
-                    shape_files=_shape_files,
-                    num_iters_per_epoch=_num_iters_per_epoch,
-                    max_cache_size=max_cache_size,
-                ),
-            )
-            for (
-                _data_path_and_name_and_type,
-                _shape_files,
-                _num_iters_per_epoch,
-            ) in zip(
-                data_path_and_name_and_type_list,
-                shape_files_list,
-                num_iters_per_epoch_list,
-            )
-        ]
-
-        # 3. Build MultipleIterFactory
-        return MultipleIterFactory(
-            build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed
-        )
-
-    @classmethod
-    def build_streaming_iterator(
-            cls,
-            data_path_and_name_and_type,
-            preprocess_fn,
-            collate_fn,
-            key_file: str = None,
-            batch_size: int = 1,
-            fs: dict = None,
-            mc: bool = False,
-            dtype: str = "float32",
-            num_workers: int = 1,
-            allow_variable_data_keys: bool = False,
-            ngpu: int = 0,
-            inference: bool = False,
-    ) -> DataLoader:
-        """Build DataLoader using iterable dataset"""
-        assert check_argument_types()
-        # For backward compatibility for pytorch DataLoader
-        if collate_fn is not None:
-            kwargs = dict(collate_fn=collate_fn)
-        else:
-            kwargs = {}
-
-        dataset = IterableESPnetDataset(
-            data_path_and_name_and_type,
-            float_dtype=dtype,
-            fs=fs,
-            mc=mc,
-            preprocess=preprocess_fn,
-            key_file=key_file,
-        )
-        if dataset.apply_utt2category:
-            kwargs.update(batch_size=1)
-        else:
-            kwargs.update(batch_size=batch_size)
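-        # NOTE: batch_size is forced to 1 when utt2category applies,
-        # presumably so a mini-batch never mixes utterances from different
-        # categories (inferred from the dataset flag, not stated here).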
-
-        cls.check_task_requirements(
-            dataset, allow_variable_data_keys, train=False, inference=inference
-        )
-
-        return DataLoader(
-            dataset=dataset,
-            pin_memory=ngpu > 0,
-            num_workers=num_workers,
-            **kwargs,
-        )
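-    # Illustrative call for inference (the data path and key names below are
-    # hypothetical):
-    #
-    #     loader = cls.build_streaming_iterator(
-    #         data_path_and_name_and_type=[("data/wav.scp", "speech", "sound")],
-    #         preprocess_fn=None,
-    #         collate_fn=None,
-    #         batch_size=1,
-    #         inference=True,
-    #     )
-    #     for keys, batch in loader:
-    #         ...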
-
-    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
-    @classmethod
-    def build_model_from_file(
-            cls,
-            config_file: Union[Path, str] = None,
-            model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
-        """Build model from the files.
-
-        This method is used for inference or fine-tuning.
-
-        Args:
-            config_file: The yaml file saved when training.
-            model_file: The model file saved when training.
-            device: Device type, "cpu", "cuda", or "cuda:N".
-
-        """
-        assert check_argument_types()
-        if config_file is None:
-            assert model_file is not None, (
-                "The argument 'model_file' must be provided "
-                "if the argument 'config_file' is not specified."
-            )
-            config_file = Path(model_file).parent / "config.yaml"
-        else:
-            config_file = Path(config_file)
-
-        with config_file.open("r", encoding="utf-8") as f:
-            args = yaml.safe_load(f)
-        if cmvn_file is not None:
-            args["cmvn_file"] = cmvn_file
-        args = argparse.Namespace(**args)
-        model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
-            raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
-            )
-        model.to(device)
-        if model_file is not None:
-            if device == "cuda":
-                # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
-                #   in PyTorch<=1.4
-                device = f"cuda:{torch.cuda.current_device()}"
-            model.load_state_dict(torch.load(model_file, map_location=device))
-        model.to(device)
-        return model, args
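-    # Illustrative call (the paths below are hypothetical):
-    #
-    #     model, args = YourTask.build_model_from_file(
-    #         config_file="exp/train/config.yaml",
-    #         model_file="exp/train/valid.acc.ave.pth",
-    #         device="cpu",
-    #     )
-    #     model.eval()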

--
Gitblit v1.9.1