From 524237d7595c6f3839a12df30959c1504fab79b0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 18:53:10 +0800
Subject: [PATCH] auto frontend

---
 examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh |    1 +
 funasr/train_utils/trainer.py                                           |    7 ++++++-
 2 files changed, 7 insertions(+), 1 deletions(-)

diff --git a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
index b3aac2b..306e23d 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
@@ -41,6 +41,7 @@
 ++dataset_conf.batch_size=1 \
 ++dataset_conf.num_workers=0 \
 ++train_conf.max_epoch=15 \
+++train_conf.save_checkpoint_interval=1000 \
 ++optim_conf.lr=0.0001 \
 ++init_param="${init_param}" \
 ++output_dir="${output_dir}" &> ${log_file} &
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 50f99f0..60fd969 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -85,7 +85,12 @@
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
-        self.validate_interval = kwargs.get("validate_interval", 5000)
+        self.validate_interval = kwargs.get("validate_interval", -1)
+        if self.validate_interval < 0:
+            self.validate_interval = self.save_checkpoint_interval
+        assert (
+            self.save_checkpoint_interval == self.validate_interval
+        ), f"save_checkpoint_interval must equal to validate_interval"
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
         self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
         self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)

--
Gitblit v1.9.1