From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 十二月 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer
---
funasr/cli/train_cli.py | 91 ++++++++++++++++++++++++---------------------
1 files changed, 49 insertions(+), 42 deletions(-)
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index 28e0e28..a22d5d4 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -25,19 +25,28 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
+from funasr.utils.download_from_hub import download_model
def preprocess_config(cfg: DictConfig):
for key, value in cfg.items():
if value == 'None':
cfg[key] = None
-
-
-@hydra.main()
-def main(kwargs: DictConfig):
- # preprocess_config(kwargs)
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
import pdb; pdb.set_trace()
+ if ":" in kwargs["model"]:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+
+ import pdb;
+ pdb.set_trace()
+ main(**kwargs)
+
+
+def main(**kwargs):
+ # preprocess_config(kwargs)
+ # import pdb; pdb.set_trace()
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
@@ -46,11 +55,11 @@
local_rank = int(os.environ.get('LOCAL_RANK', 0))
# Check if we are using DDP or FSDP
- use_ddp = 'WORLD_SIZE' in os.environ
+ use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
use_fsdp = kwargs.get("use_fsdp", None)
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
- device= torch.cuda.set_device(local_rank)
+ torch.cuda.set_device(local_rank)
# build_tokenizer
@@ -72,9 +81,23 @@
# model_class = load_class_from_path(kwargs.get("model").split(":"))
model_class = dynamic_import(kwargs.get("model"))
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- # model = model.to(device=kwargs.get("device", "cpu"))
-
-
+ frontend = model.frontend
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if not isinstance(init_param, Sequence):
+ init_param = (init_param,)
+ logging.info("init_param is not None: %s", init_param)
+ for p in init_param:
+ logging.info(f"Loading pretrained params from {p}")
+ load_pretrained_model(
+ model=model,
+ init_param=p,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+ else:
+ initialize(model, kwargs.get("init", "kaiming_normal"))
# import pdb;
# pdb.set_trace()
@@ -84,7 +107,7 @@
freeze_param = eval(freeze_param)
if isinstance(freeze_param, Sequence):
freeze_param = (freeze_param,)
- logging.info("freeze_param is not None: ", freeze_param)
+ logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
@@ -94,9 +117,12 @@
if use_ddp:
model = model.cuda(local_rank)
- model = DDP(model, device_ids=[local_rank])
+ model = DDP(model, device_ids=[local_rank],
+ find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
model = FSDP(model).cuda(local_rank)
+ else:
+ model = model.to(device=kwargs.get("device", "cuda"))
# optim
@@ -111,36 +137,24 @@
scheduler_class = scheduler_choices.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- init_param = eval(init_param)
- if isinstance(init_param, Sequence):
- init_param = (init_param,)
- logging.info("init_param is not None: ", freeze_param)
- for p in init_param:
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- init_param=p,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
- else:
- initialize(model, kwargs.get("init", "kaiming_normal"))
-
# dataset
- dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=model.frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+ dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
# dataloader
batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
- num_workers=kwargs.get("num_workers", 0),
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
pin_memory=True)
-
+
+ if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ logging.info("config.yaml is saved to: %s", yaml_file)
+
trainer = Trainer(
model=model,
optim=optim,
@@ -158,13 +172,6 @@
torch.distributed.destroy_process_group()
-
-def train(epoch, model, op):
- pass
-
-def val():
- pass
-
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main_hydra()
\ No newline at end of file
--
Gitblit v1.9.1