From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 19:50:07 +0800
Subject: [PATCH] update

---
 funasr/tasks/abs_task.py |  297 ++++++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 203 insertions(+), 94 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index d716423..6922ae0 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -25,10 +25,12 @@
 import humanfriendly
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.multiprocessing
 import torch.nn
 import torch.optim
 import yaml
+from funasr.train.abs_espnet_model import AbsESPnetModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -38,22 +40,23 @@
 from funasr.datasets.dataset import DATA_TYPES
 from funasr.datasets.dataset import ESPnetDataset
 from funasr.datasets.iterable_dataset import IterableESPnetDataset
-from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
 from funasr.iterators.abs_iter_factory import AbsIterFactory
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
+from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer
@@ -68,6 +71,7 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
+from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 try:
@@ -82,6 +86,7 @@
 
 optim_classes = dict(
     adam=torch.optim.Adam,
+    fairseq_adam=FairseqAdam,
     adamw=torch.optim.AdamW,
     sgd=SGD,
     adadelta=torch.optim.Adadelta,
@@ -148,6 +153,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -182,6 +188,7 @@
     num_optimizers: int = 1
     trainer = Trainer
     class_choices_list: List[ClassChoices] = []
+    finetune_args: None
 
     def __init__(self):
         raise RuntimeError("This class can't be instantiated.")
@@ -279,7 +286,7 @@
 
         # NOTE(kamo): add_arguments(..., required=True) can't be used
         #  to provide --print_config mode. Instead of it, do as
-        parser.set_defaults(required=["output_dir"])
+        # parser.set_defaults(required=["output_dir"])
 
         group = parser.add_argument_group("Common configuration")
 
@@ -457,6 +464,12 @@
             default=sys.maxsize,
             help="The maximum number update step to train",
         )
+        parser.add_argument(
+            "--batch_interval",
+            type=int,
+            default=10000,
+            help="The batch interval for saving model.",
+        )
         group.add_argument(
             "--patience",
             type=int_or_none,
@@ -632,12 +645,12 @@
                  "and exclude_keys excludes keys of model states for the initialization."
                  "e.g.\n"
                  "  # Load all parameters"
-                 "  --init_param some/where/model.pth\n"
+                 "  --init_param some/where/model.pb\n"
                  "  # Load only decoder parameters"
-                 "  --init_param some/where/model.pth:decoder:decoder\n"
+                 "  --init_param some/where/model.pb:decoder:decoder\n"
                  "  # Load only decoder parameters excluding decoder.embed"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
         )
         group.add_argument(
             "--ignore_init_mismatch",
@@ -696,7 +709,7 @@
         group.add_argument(
             "--batch_type",
             type=str,
-            default="folded",
+            default="length",
             choices=list(BATCH_TYPES),
             help=_batch_type_help,
         )
@@ -706,6 +719,18 @@
             default=None,
             choices=list(BATCH_TYPES) + [None],
             help="If not given, the value of --batch_type is used",
+        )
+        group.add_argument(
+            "--speech_length_min",
+            type=int,
+            default=-1,
+            help="speech length min",
+        )
+        group.add_argument(
+            "--speech_length_max",
+            type=int,
+            default=-1,
+            help="speech length max",
         )
         group.add_argument("--fold_length", type=int, action="append", default=[])
         group.add_argument(
@@ -878,6 +903,11 @@
             help="flag to indicate whether training on PAI",
         )
         group.add_argument(
+            "--simple_ddp",
+            type=str2bool,
+            default=False,
+        )
+        group.add_argument(
             "--num_worker_count",
             type=int,
             default=1,
@@ -1005,29 +1035,30 @@
     @classmethod
     def check_required_command_args(cls, args: argparse.Namespace):
         assert check_argument_types()
-        for k in vars(args):
-            if "-" in k:
-                raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
+        if hasattr(args, "required"):
+            for k in vars(args):
+                if "-" in k:
+                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
 
-        required = ", ".join(
-            f"--{a}" for a in args.required if getattr(args, a) is None
-        )
-
-        if len(required) != 0:
-            parser = cls.get_parser()
-            parser.print_help(file=sys.stderr)
-            p = Path(sys.argv[0]).name
-            print(file=sys.stderr)
-            print(
-                f"{p}: error: the following arguments are required: " f"{required}",
-                file=sys.stderr,
+            required = ", ".join(
+                f"--{a}" for a in args.required if getattr(args, a) is None
             )
-            sys.exit(2)
+
+            if len(required) != 0:
+                parser = cls.get_parser()
+                parser.print_help(file=sys.stderr)
+                p = Path(sys.argv[0]).name
+                print(file=sys.stderr)
+                print(
+                    f"{p}: error: the following arguments are required: " f"{required}",
+                    file=sys.stderr,
+                )
+                sys.exit(2)
 
     @classmethod
     def check_task_requirements(
             cls,
-            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
+            dataset: Union[AbsDataset, IterableESPnetDataset],
             allow_variable_data_keys: bool,
             train: bool,
             inference: bool = False,
@@ -1087,6 +1118,22 @@
             cls.main_worker(args)
 
     @classmethod
+    def run(cls):
+        assert hasattr(cls, "finetune_args")
+        args = cls.finetune_args
+        args.train_shape_file = None
+        if args.distributed:
+            args.simple_ddp = True
+        else:
+            args.simple_ddp = False
+            args.ngpu = 1
+        args.use_pai = False
+        args.batch_type = "length"
+        args.oss_bucket = None
+        args.input_size = None
+        cls.main_worker(args)
+
+    @classmethod
     def main_worker(cls, args: argparse.Namespace):
         assert check_argument_types()
 
@@ -1095,8 +1142,51 @@
         # Setting distributed_option.dist_rank, etc.
         if args.use_pai:
             distributed_option.init_options_pai()
-        else:
+        elif not args.simple_ddp:
             distributed_option.init_options()
+
+        # Invoking torch.distributed.init_process_group
+        if args.use_pai:
+            distributed_option.init_torch_distributed_pai(args)
+        elif not args.simple_ddp:
+            distributed_option.init_torch_distributed(args)
+        elif args.distributed and args.simple_ddp:
+            distributed_option.init_torch_distributed_pai(args)
+            args.ngpu = dist.get_world_size()
+            if args.dataset_type == "small":
+                if args.batch_size is not None:
+                    args.batch_size = args.batch_size * args.ngpu
+                if args.batch_bins is not None:
+                    args.batch_bins = args.batch_bins * args.ngpu
+
+        # filter samples if wav.scp and text are mismatch
+        if (
+                args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                filter_wav_text(args.data_dir, args.train_set)
+                filter_wav_text(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
+
+        if args.train_shape_file is None and args.dataset_type == "small":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+            if args.simple_ddp:
+                dist.barrier()
+            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
+
+        if args.train_data_file is None and args.dataset_type == "large":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                generate_data_list(args.data_dir, args.train_set)
+                generate_data_list(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
+            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
 
         # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
         if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -1112,23 +1202,27 @@
             # logging.basicConfig() is invoked in main_worker() instead of main()
             # because it can be invoked only once in a process.
             # FIXME(kamo): Should we use logging.getLogger()?
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             logging.basicConfig(
                 level=args.log_level,
                 format=f"[{os.uname()[1].split('.')[0]}]"
                        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
             )
         else:
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             # Suppress logging if RANK != 0
             logging.basicConfig(
                 level="ERROR",
                 format=f"[{os.uname()[1].split('.')[0]}]"
                        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
             )
-        # Invoking torch.distributed.init_process_group
-        if args.use_pai:
-            distributed_option.init_torch_distributed_pai(args)
-        else:
-            distributed_option.init_torch_distributed(args)
+        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                       distributed_option.dist_rank,
+                                                                       distributed_option.local_rank))
 
         # 1. Set random-seed
         set_all_random_seed(args.seed)
@@ -1202,6 +1296,52 @@
 
         if args.dry_run:
             pass
+        elif args.collect_stats:
+            # Perform on collect_stats mode. This mode has two roles
+            # - Derive the length and dimension of all input data
+            # - Accumulate feats, square values, and the length for whitening
+
+            if args.valid_batch_size is None:
+                args.valid_batch_size = args.batch_size
+
+            if len(args.train_shape_file) != 0:
+                train_key_file = args.train_shape_file[0]
+            else:
+                train_key_file = None
+            if len(args.valid_shape_file) != 0:
+                valid_key_file = args.valid_shape_file[0]
+            else:
+                valid_key_file = None
+
+            collect_stats(
+                model=model,
+                train_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+                    key_file=train_key_file,
+                    batch_size=args.batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                valid_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+                    key_file=valid_key_file,
+                    batch_size=args.valid_batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                output_dir=output_dir,
+                ngpu=args.ngpu,
+                log_interval=args.log_interval,
+                write_collected_feats=args.write_collected_feats,
+            )
         else:
             logging.info("Training args: {}".format(args))
             # 6. Loads pre-trained model
@@ -1222,10 +1362,24 @@
             # 7. Build iterator factories
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list,
-                                                   args.config, mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list,
-                                                   args.config, mode="eval")
+                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
+                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
+                                                   mode="train")
+                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
+                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
+                                                   mode="eval")
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
                     args=args,
@@ -1437,12 +1591,18 @@
     ) -> AbsIterFactory:
         assert check_argument_types()
 
+        if args.frontend_conf is not None and "fs" in args.frontend_conf:
+            dest_sample_rate = args.frontend_conf["fs"]
+        else:
+            dest_sample_rate = 16000
+
         dataset = ESPnetDataset(
             iter_options.data_path_and_name_and_type,
             float_dtype=args.train_dtype,
             preprocess=iter_options.preprocess_fn,
             max_cache_size=iter_options.max_cache_size,
             max_cache_fd=iter_options.max_cache_fd,
+            dest_sample_rate=dest_sample_rate,
         )
         cls.check_task_requirements(
             dataset, args.allow_variable_data_keys, train=iter_options.train
@@ -1713,6 +1873,8 @@
             collate_fn,
             key_file: str = None,
             batch_size: int = 1,
+            fs: dict = None,
+            mc: bool = False,
             dtype: str = np.float32,
             num_workers: int = 1,
             allow_variable_data_keys: bool = False,
@@ -1730,6 +1892,8 @@
         dataset = IterableESPnetDataset(
             data_path_and_name_and_type,
             float_dtype=dtype,
+            fs=fs,
+            mc=mc,
             preprocess=preprocess_fn,
             key_file=key_file,
         )
@@ -1749,70 +1913,13 @@
             **kwargs,
         )
 
-    @classmethod
-    def build_streaming_iterator_modelscope(
-            cls,
-            data_path_and_name_and_type,
-            preprocess_fn,
-            collate_fn,
-            key_file: str = None,
-            batch_size: int = 1,
-            dtype: str = np.float32,
-            num_workers: int = 1,
-            allow_variable_data_keys: bool = False,
-            ngpu: int = 0,
-            inference: bool = False,
-            sample_rate: Union[dict, int] = 16000
-    ) -> DataLoader:
-        """Build DataLoader using iterable dataset"""
-        assert check_argument_types()
-        # For backward compatibility for pytorch DataLoader
-        if collate_fn is not None:
-            kwargs = dict(collate_fn=collate_fn)
-        else:
-            kwargs = {}
-
-        audio_data = data_path_and_name_and_type[0]
-        if isinstance(audio_data, bytes):
-            dataset = IterableESPnetBytesModelScope(
-                data_path_and_name_and_type,
-                float_dtype=dtype,
-                preprocess=preprocess_fn,
-                key_file=key_file,
-                sample_rate=sample_rate
-            )
-        else:
-            dataset = IterableESPnetDatasetModelScope(
-                data_path_and_name_and_type,
-                float_dtype=dtype,
-                preprocess=preprocess_fn,
-                key_file=key_file,
-                sample_rate=sample_rate
-            )
-
-        if dataset.apply_utt2category:
-            kwargs.update(batch_size=1)
-        else:
-            kwargs.update(batch_size=batch_size)
-
-        cls.check_task_requirements(dataset,
-                                    allow_variable_data_keys,
-                                    train=False,
-                                    inference=inference)
-
-        return DataLoader(
-            dataset=dataset,
-            pin_memory=ngpu > 0,
-            num_workers=num_workers,
-            **kwargs,
-        )
-
     # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
     @classmethod
     def build_model_from_file(
             cls,
             config_file: Union[Path, str] = None,
             model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
     ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
         """Build model from the files.
@@ -1837,6 +1944,8 @@
 
         with config_file.open("r", encoding="utf-8") as f:
             args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
         if not isinstance(model, AbsESPnetModel):
@@ -1850,5 +1959,5 @@
                 #   in PyTorch<=1.4
                 device = f"cuda:{torch.cuda.current_device()}"
             model.load_state_dict(torch.load(model_file, map_location=device))
-
+        model.to(device)
         return model, args

--
Gitblit v1.9.1