From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Mon, 24 Apr 2023 19:50:07 +0800
Subject: [PATCH] abs_task: add collect_stats mode, new optimizer/scheduler choices, and wav/text filtering

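Register FairseqAdam and TriStageLR as optimizer/scheduler choices;
add a collect_stats mode that derives input shapes and accumulates
feature statistics; filter mismatched wav.scp/text samples before
shape calculation; add --batch_interval for periodic checkpointing;
reset stale logging handlers in distributed workers; forward frontend,
punctuation, and BPE-model options to the large-dataset ArkDataLoader;
honor the frontend sampling rate when building datasets; and update
the --init_param help examples to the model.pb naming.
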
---
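Note: with the registrations below, the new optimizer and scheduler
should become selectable through the task's existing class-choice
flags, e.g. (flag names assumed, not verified against a recipe):

    <train entry point> --optim fairseq_adam --scheduler tri_stage
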
 funasr/tasks/abs_task.py |  110 +++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 102 insertions(+), 8 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5424f13..6922ae0 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@
 import torch.nn
 import torch.optim
 import yaml
+from funasr.train.abs_espnet_model import AbsESPnetModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -43,17 +44,19 @@
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
+from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer
@@ -68,7 +71,7 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
-from funasr.utils.wav_utils import calc_shape, generate_data_list
+from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 try:
@@ -83,6 +86,7 @@
 
 optim_classes = dict(
     adam=torch.optim.Adam,
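+    # Adam variant ported from fairseq (funasr/optimizers/fairseq_adam.py).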
+    fairseq_adam=FairseqAdam,
     adamw=torch.optim.AdamW,
     sgd=SGD,
     adadelta=torch.optim.Adadelta,
@@ -149,6 +153,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
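+    # Fairseq-style three-stage schedule: warmup, hold, then decay.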
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
@@ -459,6 +464,12 @@
             default=sys.maxsize,
             help="The maximum number update step to train",
         )
+        group.add_argument(
+            "--batch_interval",
+            type=int,
+            default=10000,
+            help="Save a checkpoint every this many training batches.",
+        )
         group.add_argument(
             "--patience",
             type=int_or_none,
@@ -634,12 +645,12 @@
                  "and exclude_keys excludes keys of model states for the initialization."
                  "e.g.\n"
                  "  # Load all parameters"
-                 "  --init_param some/where/model.pth\n"
+                 "  --init_param some/where/model.pb\n"
                  "  # Load only decoder parameters"
-                 "  --init_param some/where/model.pth:decoder:decoder\n"
+                 "  --init_param some/where/model.pb:decoder:decoder\n"
                  "  # Load only decoder parameters excluding decoder.embed"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
-                 "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+                 "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
         )
         group.add_argument(
             "--ignore_init_mismatch",
@@ -1148,10 +1159,21 @@
                 if args.batch_bins is not None:
                     args.batch_bins = args.batch_bins * args.ngpu
 
+        # Filter out samples whose wav.scp and text entries do not match.
+        if (args.train_shape_file is None and args.dataset_type == "small") or (
+                args.train_data_file is None and args.dataset_type == "large"):
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                filter_wav_text(args.data_dir, args.train_set)
+                filter_wav_text(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
+
         if args.train_shape_file is None and args.dataset_type == "small":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
-                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
             if args.simple_ddp:
                 dist.barrier()
             args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1180,12 +1202,18 @@
             # logging.basicConfig() is invoked in main_worker() instead of main()
             # because it can be invoked only once in a process.
             # FIXME(kamo): Should we use logging.getLogger()?
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             logging.basicConfig(
                 level=args.log_level,
                 format=f"[{os.uname()[1].split('.')[0]}]"
                        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
             )
         else:
+            # BUGFIX: Remove previous handlers and reset log level
+            for handler in logging.root.handlers[:]:
+                logging.root.removeHandler(handler)
             # Suppress logging if RANK != 0
             logging.basicConfig(
                 level="ERROR",
@@ -1268,6 +1296,52 @@
 
         if args.dry_run:
             pass
+        elif args.collect_stats:
+            # Run in collect_stats mode. This mode has two roles:
+            # - Derive the length and dimension of all input data
+            # - Accumulate feats, squared values, and lengths for whitening
+
+            if args.valid_batch_size is None:
+                args.valid_batch_size = args.batch_size
+
+            if len(args.train_shape_file) != 0:
+                train_key_file = args.train_shape_file[0]
+            else:
+                train_key_file = None
+            if len(args.valid_shape_file) != 0:
+                valid_key_file = args.valid_shape_file[0]
+            else:
+                valid_key_file = None
+
+            collect_stats(
+                model=model,
+                train_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+                    key_file=train_key_file,
+                    batch_size=args.batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                valid_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+                    key_file=valid_key_file,
+                    batch_size=args.valid_batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                output_dir=output_dir,
+                ngpu=args.ngpu,
+                log_interval=args.log_interval,
+                write_collected_feats=args.write_collected_feats,
+            )
         else:
             logging.info("Training args: {}".format(args))
             # 6. Loads pre-trained model
@@ -1289,12 +1363,22 @@
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                 train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=getattr(args, "frontend_conf", None),
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
+                                                   punc_dict_file=getattr(args, "punc_list", None),
+                                                   bpemodel_file=getattr(args, "bpemodel", None),
                                                    mode="train")
                 valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=getattr(args, "frontend_conf", None),
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
+                                                   punc_dict_file=getattr(args, "punc_list", None),
+                                                   bpemodel_file=getattr(args, "bpemodel", None),
                                                    mode="eval")
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
@@ -1507,12 +1591,18 @@
     ) -> AbsIterFactory:
         assert check_argument_types()
 
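+        # Pick the target sampling rate from the frontend config ("fs"), defaulting to 16 kHz.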
+        if args.frontend_conf is not None and "fs" in args.frontend_conf:
+            dest_sample_rate = args.frontend_conf["fs"]
+        else:
+            dest_sample_rate = 16000
+
         dataset = ESPnetDataset(
             iter_options.data_path_and_name_and_type,
             float_dtype=args.train_dtype,
             preprocess=iter_options.preprocess_fn,
             max_cache_size=iter_options.max_cache_size,
             max_cache_fd=iter_options.max_cache_fd,
+            dest_sample_rate=dest_sample_rate,
         )
         cls.check_task_requirements(
             dataset, args.allow_variable_data_keys, train=iter_options.train
@@ -1783,6 +1873,8 @@
             collate_fn,
             key_file: str = None,
             batch_size: int = 1,
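+            # fs / mc: target sampling rate(s) and multi-channel flag (assumed
+            # semantics), forwarded to IterableESPnetDataset below.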
+            fs: dict = None,
+            mc: bool = False,
             dtype: str = np.float32,
             num_workers: int = 1,
             allow_variable_data_keys: bool = False,
@@ -1800,6 +1892,8 @@
         dataset = IterableESPnetDataset(
             data_path_and_name_and_type,
             float_dtype=dtype,
+            fs=fs,
+            mc=mc,
             preprocess=preprocess_fn,
             key_file=key_file,
         )

--
Gitblit v1.9.1