From 9a8086bdf5c15d45087175f170ac33d574c0ed79 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 四月 2024 10:48:31 +0800
Subject: [PATCH] batch

---
 funasr/datasets/audio_datasets/update_jsonl.py |    4 ++--
 funasr/train_utils/trainer.py                  |    2 +-
 funasr/bin/train.py                            |    3 ++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 97516eb..d20915c 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -205,7 +205,6 @@
             dataloader_tr, dataloader_val = dataloader.build_iter(
                 epoch, data_split_i=data_split_i, start_step=trainer.start_step
             )
-            trainer.start_step = 0
 
             trainer.train_epoch(
                 model=model,
@@ -218,7 +217,9 @@
                 writer=writer,
                 data_split_i=data_split_i,
                 data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
             )
+            trainer.start_step = 0
 
             torch.cuda.empty_cache()
 
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
index ad47e12..05870fe 100644
--- a/funasr/datasets/audio_datasets/update_jsonl.py
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -50,8 +50,8 @@
     sample_num = len(waveform)
     source_len = int(sample_num / 16000 * 1000 / 10)
     source_len_old = data["source_len"]
-    if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
-        logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+    # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+    #     logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
     data["source_len"] = source_len
     data["source"] = wav_path
     jsonl_line = json.dumps(data, ensure_ascii=False)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index e86420c..a28ca51 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -456,7 +456,7 @@
                     batch_num_epoch = len(dataloader_train)
                 self.log(
                     epoch,
-                    batch_idx,
+                    batch_idx + kwargs.get("start_step", 0),
                     step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,

--
Gitblit v1.9.1