From 98abc0e5ac1a1da0fe1802d9ffb623802fbf0b2f Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 29 六月 2023 16:30:39 +0800
Subject: [PATCH] update setup (#686)

---
 funasr/tasks/abs_task.py |  173 ++++++++++++++++++++++++++++++++++++++++++---------------
 1 files changed, 126 insertions(+), 47 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 4e79c63..91d33c5 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,9 +30,8 @@
 import torch.nn
 import torch.optim
 import yaml
+from funasr.models.base_model import FunASRModel
 from torch.utils.data import DataLoader
-from typeguard import check_argument_types
-from typeguard import check_return_type
 
 from funasr import __version__
 from funasr.datasets.dataset import AbsDataset
@@ -43,17 +42,19 @@
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
+from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer
@@ -68,7 +69,7 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
-from funasr.utils.wav_utils import calc_shape, generate_data_list
+from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 try:
@@ -83,6 +84,7 @@
 
 optim_classes = dict(
     adam=torch.optim.Adam,
+    fairseq_adam=FairseqAdam,
     adamw=torch.optim.AdamW,
     sgd=SGD,
     adadelta=torch.optim.Adadelta,
@@ -149,6 +151,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -225,8 +228,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as following,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "required_data_names" should be as
@@ -246,8 +249,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as follows,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "optional_data_names" should be as
@@ -258,12 +261,12 @@
 
     @classmethod
     @abstractmethod
-    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
         raise NotImplementedError
+
 
     @classmethod
     def get_parser(cls) -> config_argparse.ArgumentParser:
-        assert check_argument_types()
 
         class ArgumentDefaultsRawTextHelpFormatter(
             argparse.RawTextHelpFormatter,
@@ -440,6 +443,12 @@
             help='Perform on "collect stats" mode',
         )
         group.add_argument(
+            "--mc",
+            type=bool,
+            default=False,
+            help="MultiChannel input",
+        )
+        group.add_argument(
             "--write_collected_feats",
             type=str2bool,
             default=False,
@@ -458,6 +467,12 @@
             type=int,
             default=sys.maxsize,
             help="The maximum number update step to train",
+        )
+        parser.add_argument(
+            "--batch_interval",
+            type=int,
+            default=-1,
+            help="The batch interval for saving model.",
         )
         group.add_argument(
             "--patience",
@@ -536,6 +551,12 @@
             type=int,
             default=1,
             help="The number of gradient accumulation",
+        )
+        group.add_argument(
+            "--bias_grad_times",
+            type=float,
+            default=1.0,
+            help="To scale the gradient of contextual related params",
         )
         group.add_argument(
             "--no_forward_run",
@@ -624,8 +645,8 @@
         group.add_argument(
             "--init_param",
             type=str,
+            action="append",
             default=[],
-            nargs="*",
             help="Specify the file path used for initialization of parameters. "
                  "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
                  "where file_path is the model file path, "
@@ -634,12 +655,12 @@
                  "and exclude_keys excludes keys of model states for the initialization."
                  "e.g.\n"
                  "  # Load all parameters"
-                 "  --init_param some/where/model.pth\n"
+                 "  --init_param some/where/model.pb\n"
                  "  # Load only decoder parameters"
-                 "  --init_param some/where/model.pth:decoder:decoder\n"
+                 "  --init_param some/where/model.pb:decoder:decoder\n"
                  "  # Load only decoder parameters excluding decoder.embed"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
         )
         group.add_argument(
             "--ignore_init_mismatch",
@@ -651,7 +672,7 @@
             "--freeze_param",
             type=str,
             default=[],
-            nargs="*",
+            action="append",
             help="Freeze parameters",
         )
 
@@ -935,7 +956,6 @@
         cls.trainer.add_arguments(parser)
         cls.add_task_arguments(parser)
 
-        assert check_return_type(parser)
         return parser
 
     @classmethod
@@ -983,7 +1003,6 @@
             return _cls
 
         # This method is used only for --print_config
-        assert check_argument_types()
         parser = cls.get_parser()
         args, _ = parser.parse_known_args()
         config = vars(args)
@@ -1023,7 +1042,6 @@
 
     @classmethod
     def check_required_command_args(cls, args: argparse.Namespace):
-        assert check_argument_types()
         if hasattr(args, "required"):
             for k in vars(args):
                 if "-" in k:
@@ -1053,7 +1071,6 @@
             inference: bool = False,
     ) -> None:
         """Check if the dataset satisfy the requirement of current Task"""
-        assert check_argument_types()
         mes = (
             f"If you intend to use an additional input, modify "
             f'"{cls.__name__}.required_data_names()" or '
@@ -1080,14 +1097,12 @@
 
     @classmethod
     def print_config(cls, file=sys.stdout) -> None:
-        assert check_argument_types()
         # Shows the config: e.g. python train.py asr --print_config
         config = cls.get_default_config()
         file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
 
     @classmethod
     def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
-        assert check_argument_types()
         print(get_commandline_args(), file=sys.stderr)
         if args is None:
             parser = cls.get_parser()
@@ -1124,7 +1139,6 @@
 
     @classmethod
     def main_worker(cls, args: argparse.Namespace):
-        assert check_argument_types()
 
         # 0. Init distributed process
         distributed_option = build_dataclass(DistributedOption, args)
@@ -1142,16 +1156,27 @@
         elif args.distributed and args.simple_ddp:
             distributed_option.init_torch_distributed_pai(args)
             args.ngpu = dist.get_world_size()
-            if args.dataset_type == "small":
+            if args.dataset_type == "small" and args.ngpu > 0:
                 if args.batch_size is not None:
                     args.batch_size = args.batch_size * args.ngpu
-                if args.batch_bins is not None:
+                if args.batch_bins is not None and args.ngpu > 0:
                     args.batch_bins = args.batch_bins * args.ngpu
+
+        # filter samples if wav.scp and text are mismatch
+        if (
+                args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                filter_wav_text(args.data_dir, args.train_set)
+                filter_wav_text(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
 
         if args.train_shape_file is None and args.dataset_type == "small":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
-                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
             if args.simple_ddp:
                 dist.barrier()
             args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1180,12 +1205,18 @@
             # logging.basicConfig() is invoked in main_worker() instead of main()
             # because it can be invoked only once in a process.
             # FIXME(kamo): Should we use logging.getLogger()?
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             logging.basicConfig(
                 level=args.log_level,
                 format=f"[{os.uname()[1].split('.')[0]}]"
                        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
             )
         else:
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             # Suppress logging if RANK != 0
             logging.basicConfig(
                 level="ERROR",
@@ -1207,9 +1238,9 @@
 
         # 2. Build model
         model = cls.build_model(args=args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model = model.to(
             dtype=getattr(torch, args.train_dtype),
@@ -1268,6 +1299,54 @@
 
         if args.dry_run:
             pass
+        elif args.collect_stats:
+            # Perform on collect_stats mode. This mode has two roles
+            # - Derive the length and dimension of all input data
+            # - Accumulate feats, square values, and the length for whitening
+
+            if args.valid_batch_size is None:
+                args.valid_batch_size = args.batch_size
+
+            if len(args.train_shape_file) != 0:
+                train_key_file = args.train_shape_file[0]
+            else:
+                train_key_file = None
+            if len(args.valid_shape_file) != 0:
+                valid_key_file = args.valid_shape_file[0]
+            else:
+                valid_key_file = None
+
+            collect_stats(
+                model=model,
+                train_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+                    key_file=train_key_file,
+                    batch_size=args.batch_size,
+                    mc=args.mc,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                valid_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+                    key_file=valid_key_file,
+                    batch_size=args.valid_batch_size,
+                    mc=args.mc,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                output_dir=output_dir,
+                ngpu=args.ngpu,
+                log_interval=args.log_interval,
+                write_collected_feats=args.write_collected_feats,
+            )
         else:
             logging.info("Training args: {}".format(args))
             # 6. Loads pre-trained model
@@ -1287,15 +1366,10 @@
 
             # 7. Build iterator factories
             if args.dataset_type == "large":
-                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   mode="eval")
+                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+                train_iter_factory = LargeDataLoader(args, mode="train")
+                valid_iter_factory = LargeDataLoader(args, mode="eval")
+
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
                     args=args,
@@ -1472,7 +1546,6 @@
         - 4 epoch with "--num_iters_per_epoch" == 4
 
         """
-        assert check_argument_types()
         iter_options = cls.build_iter_options(args, distributed_option, mode)
 
         # Overwrite iter_options if any kwargs is given
@@ -1505,7 +1578,14 @@
     def build_sequence_iter_factory(
             cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
     ) -> AbsIterFactory:
-        assert check_argument_types()
+
+        if hasattr(args, "frontend_conf"):
+            if args.frontend_conf is not None and "fs" in args.frontend_conf:
+                dest_sample_rate = args.frontend_conf["fs"]
+            else:
+                dest_sample_rate = 16000
+        else:
+            dest_sample_rate = 16000
 
         dataset = ESPnetDataset(
             iter_options.data_path_and_name_and_type,
@@ -1513,6 +1593,7 @@
             preprocess=iter_options.preprocess_fn,
             max_cache_size=iter_options.max_cache_size,
             max_cache_fd=iter_options.max_cache_fd,
+            dest_sample_rate=dest_sample_rate,
         )
         cls.check_task_requirements(
             dataset, args.allow_variable_data_keys, train=iter_options.train
@@ -1590,7 +1671,6 @@
             iter_options: IteratorOptions,
             mode: str,
     ) -> AbsIterFactory:
-        assert check_argument_types()
 
         dataset = ESPnetDataset(
             iter_options.data_path_and_name_and_type,
@@ -1695,7 +1775,6 @@
     def build_multiple_iter_factory(
             cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
     ):
-        assert check_argument_types()
         iter_options = cls.build_iter_options(args, distributed_option, mode)
         assert len(iter_options.data_path_and_name_and_type) > 0, len(
             iter_options.data_path_and_name_and_type
@@ -1784,6 +1863,7 @@
             key_file: str = None,
             batch_size: int = 1,
             fs: dict = None,
+            mc: bool = False,
             dtype: str = np.float32,
             num_workers: int = 1,
             allow_variable_data_keys: bool = False,
@@ -1791,7 +1871,6 @@
             inference: bool = False,
     ) -> DataLoader:
         """Build DataLoader using iterable dataset"""
-        assert check_argument_types()
         # For backward compatibility for pytorch DataLoader
         if collate_fn is not None:
             kwargs = dict(collate_fn=collate_fn)
@@ -1802,6 +1881,7 @@
             data_path_and_name_and_type,
             float_dtype=dtype,
             fs=fs,
+            mc=mc,
             preprocess=preprocess_fn,
             key_file=key_file,
         )
@@ -1829,7 +1909,7 @@
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
-    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+    ) -> Tuple[FunASRModel, argparse.Namespace]:
         """Build model from the files.
 
         This method is used for inference or fine-tuning.
@@ -1840,7 +1920,6 @@
             device: Device type, "cpu", "cuda", or "cuda:N".
 
         """
-        assert check_argument_types()
         if config_file is None:
             assert model_file is not None, (
                 "The argument 'model_file' must be provided "
@@ -1856,9 +1935,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         if model_file is not None:

--
Gitblit v1.9.1