| | |
| | | from funasr.build_utils.build_scheduler import build_scheduler |
| | | from funasr.build_utils.build_trainer import build_trainer |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | |
| | | else: |
| | | yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) |
| | | |
| | | for p in args.init_param: |
| | | logging.info(f"Loading pretrained params from {p}") |
| | | load_pretrained_model( |
| | | model=model, |
| | | init_param=p, |
| | | ignore_init_mismatch=args.ignore_init_mismatch, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" |
| | | if args.ngpu > 0 |
| | | else "cpu", |
| | | oss_bucket=args.oss_bucket, |
| | | ) |
| | | |
| | | # dataloader for training/validation |
| | | train_dataloader, valid_dataloader = build_dataloader(args) |
| | | |