嘉渊
2023-04-24 a4ab665d309ad1437c813aa61e5e84cc50996c4d
update
2 files changed, 19 lines modified
funasr/bin/train.py                 | 15
funasr/build_utils/build_trainer.py |  4
funasr/bin/train.py
@@ -12,6 +12,7 @@
 from funasr.build_utils.build_model import build_model
 from funasr.build_utils.build_optimizer import build_optimizer
 from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.build_utils.build_trainer import build_trainer
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -443,4 +444,18 @@
             else:
                 yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
+    # dataloader for training/validation
+    train_dataloader, valid_dataloader = build_dataloader(args)
+
+    # Trainer, including model, optimizers, etc.
+    trainer = build_trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        distributed_option=distributed_option
+    )
+    trainer.run()
funasr/build_utils/build_trainer.py
@@ -107,7 +107,6 @@
                  schedulers: Sequence[Optional[AbsScheduler]],
                  train_dataloader: AbsIterFactory,
                  valid_dataloader: AbsIterFactory,
-                 trainer_options,
                  distributed_option: DistributedOption):
         self.trainer_options = self.build_options(args)
         self.model = model
@@ -115,7 +114,6 @@
         self.schedulers = schedulers
         self.train_dataloader = train_dataloader
         self.valid_dataloader = valid_dataloader
-        self.trainer_options = trainer_options
         self.distributed_option = distributed_option

     def build_options(self, args: argparse.Namespace) -> TrainerOptions:
@@ -808,7 +806,6 @@
         schedulers: Sequence[Optional[AbsScheduler]],
         train_dataloader: AbsIterFactory,
         valid_dataloader: AbsIterFactory,
-        trainer_options,
         distributed_option: DistributedOption
 ):
     trainer = Trainer(
@@ -818,7 +815,6 @@
         schedulers=schedulers,
         train_dataloader=train_dataloader,
         valid_dataloader=valid_dataloader,
-        trainer_options=trainer_options,
         distributed_option=distributed_option
     )
     return trainer
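
Taken together, the two files make the same change from both sides: trainer_options is removed from the Trainer constructor and from build_trainer(), since the constructor already derives its options from args via self.build_options(args), while train.py now builds the dataloaders, constructs the trainer through build_trainer, and starts training with trainer.run(). Below is a minimal sketch of the resulting interface; anything not visible in the hunks (the leading args/model/optimizers parameters, imports, the body of build_options) is assumed, not copied from the repository.

    # Sketch only: names taken from the diff context; argument order, imports and
    # the build_options implementation are assumptions.
    class Trainer:
        def __init__(self, args, model, optimizers, schedulers,
                     train_dataloader, valid_dataloader, distributed_option):
            # trainer_options is no longer injected by the caller; the options
            # are derived from the parsed command-line arguments here.
            self.trainer_options = self.build_options(args)
            self.model = model
            self.optimizers = optimizers
            self.schedulers = schedulers
            self.train_dataloader = train_dataloader
            self.valid_dataloader = valid_dataloader
            self.distributed_option = distributed_option

    def build_trainer(args, model, optimizers, schedulers,
                      train_dataloader, valid_dataloader, distributed_option):
        # Callers such as the new block in funasr/bin/train.py pass only these
        # arguments; no TrainerOptions object is built at the call site anymore.
        return Trainer(args=args, model=model, optimizers=optimizers,
                       schedulers=schedulers, train_dataloader=train_dataloader,
                       valid_dataloader=valid_dataloader,
                       distributed_option=distributed_option)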