游雁
2024-06-06 524237d7595c6f3839a12df30959c1504fab79b0
auto frontend
2个文件已修改
8 ■■■■ 已修改文件
examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
@@ -41,6 +41,7 @@
++dataset_conf.batch_size=1 \
++dataset_conf.num_workers=0 \
++train_conf.max_epoch=15 \
++train_conf.save_checkpoint_interval=1000 \
++optim_conf.lr=0.0001 \
++init_param="${init_param}" \
++output_dir="${output_dir}" &> ${log_file} &
funasr/train_utils/trainer.py
@@ -85,7 +85,12 @@
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
-        self.validate_interval = kwargs.get("validate_interval", 5000)
+        self.validate_interval = kwargs.get("validate_interval", -1)
+        if self.validate_interval < 0:
+            self.validate_interval = self.save_checkpoint_interval
+        assert (
+            self.save_checkpoint_interval == self.validate_interval
+        ), f"save_checkpoint_interval must equal to validate_interval"
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
         self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
         self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)