From adcee8828ef5d78b575043954deb662a35e318f7 Mon Sep 17 00:00:00 2001
From: huangmingming <huangmingming@deepscience.cn>
Date: Mon, 30 Jan 2023 16:02:54 +0800
Subject: [PATCH] update the minimum size of audio

---
 funasr/tasks/abs_task.py |  132 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 103 insertions(+), 29 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5ea78c3..5424f13 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -25,6 +25,7 @@
 import humanfriendly
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.multiprocessing
 import torch.nn
 import torch.optim
@@ -67,6 +68,7 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
+from funasr.utils.wav_utils import calc_shape, generate_data_list
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 try:
@@ -181,6 +183,7 @@
     num_optimizers: int = 1
     trainer = Trainer
     class_choices_list: List[ClassChoices] = []
+    finetune_args: None
 
     def __init__(self):
         raise RuntimeError("This class can't be instantiated.")
@@ -278,7 +281,7 @@
 
         # NOTE(kamo): add_arguments(..., required=True) can't be used
         #  to provide --print_config mode. Instead of it, do as
-        parser.set_defaults(required=["output_dir"])
+        # parser.set_defaults(required=["output_dir"])
 
         group = parser.add_argument_group("Common configuration")
 
@@ -695,7 +698,7 @@
         group.add_argument(
             "--batch_type",
             type=str,
-            default="folded",
+            default="length",
             choices=list(BATCH_TYPES),
             help=_batch_type_help,
         )
@@ -705,6 +708,18 @@
             default=None,
             choices=list(BATCH_TYPES) + [None],
             help="If not given, the value of --batch_type is used",
+        )
+        group.add_argument(
+            "--speech_length_min",
+            type=int,
+            default=-1,
+            help="speech length min",
+        )
+        group.add_argument(
+            "--speech_length_max",
+            type=int,
+            default=-1,
+            help="speech length max",
         )
         group.add_argument("--fold_length", type=int, action="append", default=[])
         group.add_argument(
@@ -877,6 +892,11 @@
             help="flag to indicate whether training on PAI",
         )
         group.add_argument(
+            "--simple_ddp",
+            type=str2bool,
+            default=False,
+        )
+        group.add_argument(
             "--num_worker_count",
             type=int,
             default=1,
@@ -1004,24 +1024,25 @@
     @classmethod
     def check_required_command_args(cls, args: argparse.Namespace):
         assert check_argument_types()
-        for k in vars(args):
-            if "-" in k:
-                raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
+        if hasattr(args, "required"):
+            for k in vars(args):
+                if "-" in k:
+                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
 
-        required = ", ".join(
-            f"--{a}" for a in args.required if getattr(args, a) is None
-        )
-
-        if len(required) != 0:
-            parser = cls.get_parser()
-            parser.print_help(file=sys.stderr)
-            p = Path(sys.argv[0]).name
-            print(file=sys.stderr)
-            print(
-                f"{p}: error: the following arguments are required: " f"{required}",
-                file=sys.stderr,
+            required = ", ".join(
+                f"--{a}" for a in args.required if getattr(args, a) is None
             )
-            sys.exit(2)
+
+            if len(required) != 0:
+                parser = cls.get_parser()
+                parser.print_help(file=sys.stderr)
+                p = Path(sys.argv[0]).name
+                print(file=sys.stderr)
+                print(
+                    f"{p}: error: the following arguments are required: " f"{required}",
+                    file=sys.stderr,
+                )
+                sys.exit(2)
 
     @classmethod
     def check_task_requirements(
@@ -1086,6 +1107,22 @@
             cls.main_worker(args)
 
     @classmethod
+    def run(cls):
+        assert hasattr(cls, "finetune_args")
+        args = cls.finetune_args
+        args.train_shape_file = None
+        if args.distributed:
+            args.simple_ddp = True
+        else:
+            args.simple_ddp = False
+            args.ngpu = 1
+        args.use_pai = False
+        args.batch_type = "length"
+        args.oss_bucket = None
+        args.input_size = None
+        cls.main_worker(args)
+
+    @classmethod
     def main_worker(cls, args: argparse.Namespace):
         assert check_argument_types()
 
@@ -1094,8 +1131,40 @@
         # Setting distributed_option.dist_rank, etc.
         if args.use_pai:
             distributed_option.init_options_pai()
-        else:
+        elif not args.simple_ddp:
             distributed_option.init_options()
+
+        # Invoking torch.distributed.init_process_group
+        if args.use_pai:
+            distributed_option.init_torch_distributed_pai(args)
+        elif not args.simple_ddp:
+            distributed_option.init_torch_distributed(args)
+        elif args.distributed and args.simple_ddp:
+            distributed_option.init_torch_distributed_pai(args)
+            args.ngpu = dist.get_world_size()
+            if args.dataset_type == "small":
+                if args.batch_size is not None:
+                    args.batch_size = args.batch_size * args.ngpu
+                if args.batch_bins is not None:
+                    args.batch_bins = args.batch_bins * args.ngpu
+
+        if args.train_shape_file is None and args.dataset_type == "small":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+            if args.simple_ddp:
+                dist.barrier()
+            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
+
+        if args.train_data_file is None and args.dataset_type == "large":
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                generate_data_list(args.data_dir, args.train_set)
+                generate_data_list(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
+            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
 
         # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
         if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -1123,11 +1192,9 @@
                 format=f"[{os.uname()[1].split('.')[0]}]"
                        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
             )
-        # Invoking torch.distributed.init_process_group
-        if args.use_pai:
-            distributed_option.init_torch_distributed_pai(args)
-        else:
-            distributed_option.init_torch_distributed(args)
+        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                       distributed_option.dist_rank,
+                                                                       distributed_option.local_rank))
 
         # 1. Set random-seed
         set_all_random_seed(args.seed)
@@ -1221,10 +1288,14 @@
             # 7. Build iterator factories
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list,
-                                                   args.config, mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list,
-                                                   args.config, mode="eval")
+                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   mode="train")
+                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   mode="eval")
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
                     args=args,
@@ -1754,6 +1825,7 @@
             cls,
             config_file: Union[Path, str] = None,
             model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
     ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
         """Build model from the files.
@@ -1778,6 +1850,8 @@
 
         with config_file.open("r", encoding="utf-8") as f:
             args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
         if not isinstance(model, AbsESPnetModel):
@@ -1791,5 +1865,5 @@
                 #   in PyTorch<=1.4
                 device = f"cuda:{torch.cuda.current_device()}"
             model.load_state_dict(torch.load(model_file, map_location=device))
-
+        model.to(device)
         return model, args

--
Gitblit v1.9.1