From adcee8828ef5d78b575043954deb662a35e318f7 Mon Sep 17 00:00:00 2001
From: huangmingming <huangmingming@deepscience.cn>
Date: Mon, 30 Jan 2023 16:02:54 +0800
Subject: [PATCH] add minimum/maximum audio length limits and simple DDP support
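
Set the default batch_type to "length" and add --speech_length_min /
--speech_length_max so utterances outside the configured length range
can be filtered when shape files are computed. In addition:

- add a --simple_ddp mode that initializes the process group via the
  PAI-style helper and scales batch_size/batch_bins by the world size
- add a run() classmethod that trains from cls.finetune_args
- auto-generate speech_shape / data.list files when none are given
- let build_model_from_file() override the CMVN file and move the
  loaded model to the target device

A minimal sketch of the new fine-tuning entry point (the task class is
illustrative, and --data_dir/--train_set/--dev_set are assumed to be
defined elsewhere in the task parser, which this file references but
does not show):

    from funasr.tasks.asr import ASRTask  # any AbsTask subclass

    args = ASRTask.get_parser().parse_args([
        "--data_dir", "data",
        "--train_set", "train",
        "--dev_set", "dev",
        "--speech_length_min", "16",    # illustrative lower bound
        "--speech_length_max", "3000",  # illustrative upper bound
    ])
    args.distributed = False  # single-process fine-tuning
    ASRTask.finetune_args = args
    ASRTask.run()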
---
funasr/tasks/abs_task.py | 132 ++++++++++++++++++++++++++++++++++---------
1 file changed, 103 insertions(+), 29 deletions(-)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5ea78c3..5424f13 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -25,6 +25,7 @@
import humanfriendly
import numpy as np
import torch
+import torch.distributed as dist
import torch.multiprocessing
import torch.nn
import torch.optim
@@ -67,6 +68,7 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_int
from funasr.utils.types import str_or_none
+from funasr.utils.wav_utils import calc_shape, generate_data_list
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
try:
@@ -181,6 +183,7 @@
num_optimizers: int = 1
trainer = Trainer
class_choices_list: List[ClassChoices] = []
+ finetune_args: argparse.Namespace  # assigned by the caller before run()
def __init__(self):
raise RuntimeError("This class can't be instantiated.")
@@ -278,7 +281,7 @@
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
- parser.set_defaults(required=["output_dir"])
+ # parser.set_defaults(required=["output_dir"])
group = parser.add_argument_group("Common configuration")
@@ -695,7 +698,7 @@
group.add_argument(
"--batch_type",
type=str,
- default="folded",
+ default="length",
choices=list(BATCH_TYPES),
help=_batch_type_help,
)
@@ -705,6 +708,18 @@
default=None,
choices=list(BATCH_TYPES) + [None],
help="If not given, the value of --batch_type is used",
+ )
+ group.add_argument(
+ "--speech_length_min",
+ type=int,
+ default=-1,
+ help="speech length min",
+ )
+ group.add_argument(
+ "--speech_length_max",
+ type=int,
+ default=-1,
+ help="speech length max",
)
group.add_argument("--fold_length", type=int, action="append", default=[])
group.add_argument(
@@ -877,6 +892,11 @@
help="flag to indicate whether training on PAI",
)
group.add_argument(
+ "--simple_ddp",
+ type=str2bool,
+ default=False,
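+ help="flag to enable simple DDP training",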
+ )
+ group.add_argument(
"--num_worker_count",
type=int,
default=1,
@@ -1004,24 +1024,25 @@
@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
assert check_argument_types()
- for k in vars(args):
- if "-" in k:
- raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
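+ # "required" is only present when set via the CLI flow; skip the check for programmatic namespaces (e.g. run()).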
+ if hasattr(args, "required"):
+ for k in vars(args):
+ if "-" in k:
+ raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
- required = ", ".join(
- f"--{a}" for a in args.required if getattr(args, a) is None
- )
-
- if len(required) != 0:
- parser = cls.get_parser()
- parser.print_help(file=sys.stderr)
- p = Path(sys.argv[0]).name
- print(file=sys.stderr)
- print(
- f"{p}: error: the following arguments are required: " f"{required}",
- file=sys.stderr,
+ required = ", ".join(
+ f"--{a}" for a in args.required if getattr(args, a) is None
)
- sys.exit(2)
+
+ if len(required) != 0:
+ parser = cls.get_parser()
+ parser.print_help(file=sys.stderr)
+ p = Path(sys.argv[0]).name
+ print(file=sys.stderr)
+ print(
+ f"{p}: error: the following arguments are required: " f"{required}",
+ file=sys.stderr,
+ )
+ sys.exit(2)
@classmethod
def check_task_requirements(
@@ -1086,6 +1107,22 @@
cls.main_worker(args)
@classmethod
+ def run(cls):
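+ """Fine-tuning entry point: train from the namespace in cls.finetune_args.
+
+ The caller assigns an argparse.Namespace to cls.finetune_args;
+ run() then normalizes it (simple DDP on/off, batch_type, shape
+ files) and delegates to main_worker().
+ """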
+ assert hasattr(cls, "finetune_args")
+ args = cls.finetune_args
+ args.train_shape_file = None
+ if args.distributed:
+ args.simple_ddp = True
+ else:
+ args.simple_ddp = False
+ args.ngpu = 1
+ args.use_pai = False
+ args.batch_type = "length"
+ args.oss_bucket = None
+ args.input_size = None
+ cls.main_worker(args)
+
+ @classmethod
def main_worker(cls, args: argparse.Namespace):
assert check_argument_types()
@@ -1094,8 +1131,40 @@
# Setting distributed_option.dist_rank, etc.
if args.use_pai:
distributed_option.init_options_pai()
- else:
+ elif not args.simple_ddp:
distributed_option.init_options()
+
+ # Invoking torch.distributed.init_process_group
+ if args.use_pai:
+ distributed_option.init_torch_distributed_pai(args)
+ elif not args.simple_ddp:
+ distributed_option.init_torch_distributed(args)
+ elif args.distributed and args.simple_ddp:
+ distributed_option.init_torch_distributed_pai(args)
+ args.ngpu = dist.get_world_size()
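+ # Under simple DDP, scale batch_size/batch_bins by the number of workers.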
+ if args.dataset_type == "small":
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
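+ # Generate speech_shape files when none are given; under simple DDP only rank 0 computes, others wait at the barrier.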
+ if args.train_shape_file is None and args.dataset_type == "small":
+ if not args.simple_ddp or distributed_option.dist_rank == 0:
+ calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+ calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+ if args.simple_ddp:
+ dist.barrier()
+ args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+ args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
+
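+ # Same rank-0-then-barrier pattern to generate data.list for the large dataset type.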
+ if args.train_data_file is None and args.dataset_type == "large":
+ if not args.simple_ddp or distributed_option.dist_rank == 0:
+ generate_data_list(args.data_dir, args.train_set)
+ generate_data_list(args.data_dir, args.dev_set)
+ if args.simple_ddp:
+ dist.barrier()
+ args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+ args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
# NOTE(kamo): Don't use logging before invoking logging.basicConfig()
if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -1123,11 +1192,9 @@
format=f"[{os.uname()[1].split('.')[0]}]"
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
- # Invoking torch.distributed.init_process_group
- if args.use_pai:
- distributed_option.init_torch_distributed_pai(args)
- else:
- distributed_option.init_torch_distributed(args)
+ logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+ distributed_option.dist_rank,
+ distributed_option.local_rank))
# 1. Set random-seed
set_all_random_seed(args.seed)
@@ -1221,10 +1288,14 @@
# 7. Build iterator factories
if args.dataset_type == "large":
from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
- train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list,
- args.config, mode="train")
- valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list,
- args.config, mode="eval")
+ seg_dict_file = getattr(args, "seg_dict_file", None)
+ train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
+                     seg_dict_file=seg_dict_file, mode="train")
+ valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                     seg_dict_file=seg_dict_file, mode="eval")
elif args.dataset_type == "small":
train_iter_factory = cls.build_iter_factory(
args=args,
@@ -1754,6 +1825,7 @@
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
device: str = "cpu",
) -> Tuple[AbsESPnetModel, argparse.Namespace]:
"""Build model from the files.
@@ -1778,6 +1850,8 @@
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
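+ # Optionally override the cmvn_file recorded in the training config.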
+ if cmvn_file is not None:
+ args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
if not isinstance(model, AbsESPnetModel):
@@ -1791,5 +1865,5 @@
# in PyTorch<=1.4
device = f"cuda:{torch.cuda.current_device()}"
model.load_state_dict(torch.load(model_file, map_location=device))
-
+ model.to(device)
return model, args
--
Gitblit v1.9.1