From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages
---
funasr/tasks/abs_task.py | 86 ++++++++++++++++++++++---------------------
1 files changed, 44 insertions(+), 42 deletions(-)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5f9e8fc..f7f13d2 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -32,8 +32,6 @@
import yaml
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr import __version__
from funasr.datasets.dataset import AbsDataset
@@ -73,6 +71,7 @@
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
import wandb
@@ -266,9 +265,9 @@
def build_model(cls, args: argparse.Namespace) -> FunASRModel:
raise NotImplementedError
+
@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:
- assert check_argument_types()
class ArgumentDefaultsRawTextHelpFormatter(
argparse.RawTextHelpFormatter,
@@ -445,6 +444,12 @@
help='Perform on "collect stats" mode',
)
group.add_argument(
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+ group.add_argument(
"--write_collected_feats",
type=str2bool,
default=False,
@@ -467,7 +472,7 @@
parser.add_argument(
"--batch_interval",
type=int,
- default=10000,
+ default=-1,
help="The batch interval for saving model.",
)
group.add_argument(
@@ -547,6 +552,12 @@
type=int,
default=1,
help="The number of gradient accumulation",
+ )
+ group.add_argument(
+ "--bias_grad_times",
+ type=float,
+ default=1.0,
+ help="To scale the gradient of contextual related params",
)
group.add_argument(
"--no_forward_run",
@@ -635,8 +646,8 @@
group.add_argument(
"--init_param",
type=str,
+ action="append",
default=[],
- nargs="*",
help="Specify the file path used for initialization of parameters. "
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
"where file_path is the model file path, "
@@ -662,7 +673,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
@@ -942,11 +953,22 @@
default=None,
help="oss bucket.",
)
+ group.add_argument(
+ "--enable_lora",
+ type=str2bool,
+ default=False,
+ help="Apply lora for finetuning.",
+ )
+ group.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="lora bias.",
+ )
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)
- assert check_return_type(parser)
return parser
@classmethod
@@ -994,7 +1016,6 @@
return _cls
# This method is used only for --print_config
- assert check_argument_types()
parser = cls.get_parser()
args, _ = parser.parse_known_args()
config = vars(args)
@@ -1034,7 +1055,6 @@
@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
- assert check_argument_types()
if hasattr(args, "required"):
for k in vars(args):
if "-" in k:
@@ -1064,7 +1084,6 @@
inference: bool = False,
) -> None:
"""Check if the dataset satisfy the requirement of current Task"""
- assert check_argument_types()
mes = (
f"If you intend to use an additional input, modify "
f'"{cls.__name__}.required_data_names()" or '
@@ -1091,14 +1110,12 @@
@classmethod
def print_config(cls, file=sys.stdout) -> None:
- assert check_argument_types()
# Shows the config: e.g. python train.py asr --print_config
config = cls.get_default_config()
file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
- assert check_argument_types()
print(get_commandline_args(), file=sys.stderr)
if args is None:
parser = cls.get_parser()
@@ -1135,7 +1152,6 @@
@classmethod
def main_worker(cls, args: argparse.Namespace):
- assert check_argument_types()
# 0. Init distributed process
distributed_option = build_dataclass(DistributedOption, args)
@@ -1153,10 +1169,10 @@
elif args.distributed and args.simple_ddp:
distributed_option.init_torch_distributed_pai(args)
args.ngpu = dist.get_world_size()
- if args.dataset_type == "small":
+ if args.dataset_type == "small" and args.ngpu > 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
+ if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
@@ -1243,6 +1259,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
@@ -1319,6 +1337,7 @@
data_path_and_name_and_type=args.train_data_path_and_name_and_type,
key_file=train_key_file,
batch_size=args.batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
@@ -1330,6 +1349,7 @@
data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
key_file=valid_key_file,
batch_size=args.valid_batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
@@ -1361,25 +1381,10 @@
# 7. Build iterator factories
if args.dataset_type == "large":
- from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
- train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args,
- "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args,
- "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args,
- "punc_list") else None,
- bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
- mode="train")
- valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args,
- "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args,
- "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args,
- "punc_list") else None,
- bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
- mode="eval")
+ from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+ train_iter_factory = LargeDataLoader(args, mode="train")
+ valid_iter_factory = LargeDataLoader(args, mode="eval")
+
elif args.dataset_type == "small":
train_iter_factory = cls.build_iter_factory(
args=args,
@@ -1556,7 +1561,6 @@
- 4 epoch with "--num_iters_per_epoch" == 4
"""
- assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
# Overwrite iter_options if any kwargs is given
@@ -1589,10 +1593,12 @@
def build_sequence_iter_factory(
cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
) -> AbsIterFactory:
- assert check_argument_types()
- if args.frontend_conf is not None and "fs" in args.frontend_conf:
- dest_sample_rate = args.frontend_conf["fs"]
+ if hasattr(args, "frontend_conf"):
+ if args.frontend_conf is not None and "fs" in args.frontend_conf:
+ dest_sample_rate = args.frontend_conf["fs"]
+ else:
+ dest_sample_rate = 16000
else:
dest_sample_rate = 16000
@@ -1680,7 +1686,6 @@
iter_options: IteratorOptions,
mode: str,
) -> AbsIterFactory:
- assert check_argument_types()
dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
@@ -1785,7 +1790,6 @@
def build_multiple_iter_factory(
cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
):
- assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
assert len(iter_options.data_path_and_name_and_type) > 0, len(
iter_options.data_path_and_name_and_type
@@ -1882,7 +1886,6 @@
inference: bool = False,
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
- assert check_argument_types()
# For backward compatibility for pytorch DataLoader
if collate_fn is not None:
kwargs = dict(collate_fn=collate_fn)
@@ -1932,7 +1935,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
--
Gitblit v1.9.1