From 9a8086bdf5c15d45087175f170ac33d574c0ed79 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 四月 2024 10:48:31 +0800
Subject: [PATCH] batch
---
funasr/datasets/audio_datasets/update_jsonl.py | 4 ++--
funasr/train_utils/trainer.py | 2 +-
funasr/bin/train.py | 3 ++-
3 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 97516eb..d20915c 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -205,7 +205,6 @@
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
- trainer.start_step = 0
trainer.train_epoch(
model=model,
@@ -218,7 +217,9 @@
writer=writer,
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
+ start_step=trainer.start_step,
)
+ trainer.start_step = 0
torch.cuda.empty_cache()
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
index ad47e12..05870fe 100644
--- a/funasr/datasets/audio_datasets/update_jsonl.py
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -50,8 +50,8 @@
sample_num = len(waveform)
source_len = int(sample_num / 16000 * 1000 / 10)
source_len_old = data["source_len"]
- if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
- logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
+ # if (source_len_old - source_len) > 100 or (source_len - source_len_old) > 100:
+ # logging.info(f"old: {source_len_old}, new: {source_len}, wav: {wav_path}")
data["source_len"] = source_len
data["source"] = wav_path
jsonl_line = json.dumps(data, ensure_ascii=False)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index e86420c..a28ca51 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -456,7 +456,7 @@
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
- batch_idx,
+ batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
--
Gitblit v1.9.1