From 2cca8104d26b454112f39b8405dcb0e70d365990 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 19 Jan 2024 17:05:08 +0800
Subject: [PATCH] Funasr1.0 (#1275)
---
funasr/models/fsmn_vad_streaming/model.py | 2 --
funasr/train_utils/trainer.py | 20 ++++++++++++++++----
funasr/bin/train.py | 2 +-
funasr/auto/auto_model.py | 2 +-
examples/industrial_data_pretraining/paraformer/finetune.sh | 4 ++--
funasr/datasets/audio_datasets/samplers.py | 2 +-
6 files changed, 21 insertions(+), 11 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 7d89876..1aff068 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -11,9 +11,9 @@
+model_revision="v2.0.2" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
+valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
++dataset_conf.batch_type="example" \
++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu" \
+debug="true"
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index bedc17d..3320136 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -132,7 +132,7 @@
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
- self.model_path = kwargs["model_path"]
+ self.model_path = kwargs.get("model_path", "./")
def build_model(self, **kwargs):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0334006..d9d4d62 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -40,7 +40,7 @@
def main(**kwargs):
-
+ print(kwargs)
# set random seed
tables.print()
set_all_random_seed(kwargs.get("seed", 0))
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index e170c68..0d93098 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -28,7 +28,7 @@
self.shuffle = shuffle and is_training
def __len__(self):
- return self.total_samples
+ return (self.total_samples-1) // self.batch_size + 1
def set_epoch(self, epoch):
np.random.seed(epoch)
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 7c21561..becfd56 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -255,7 +255,6 @@
self.waveform = None
self.last_drop_frames = 0
-
@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
"""
@@ -500,7 +499,6 @@
# # reset class variables and clear the dict for the next query
# self.AllResetDetection()
return segments
-
def init_cache(self, cache: dict = {}, **kwargs):
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 91b30b0..62d6be8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -147,9 +147,17 @@
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
self._validate_epoch(epoch)
-
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+
if self.rank == 0:
self._save_checkpoint(epoch)
@@ -164,7 +172,9 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
- self.writer.close()
+
+ if self.writer:
+ self.writer.close()
def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@
continue
# Execute an optimization step (update model parameters)
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@
pbar.update(1)
if self.local_rank == 0:
description = (
- f"Epoch: {epoch}/{self.max_epoch}, "
+ f"Train epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@
pbar.update(1)
if self.local_rank == 0:
description = (
- f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+ f"validation epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
--
Gitblit v1.9.1