From 2cca8104d26b454112f39b8405dcb0e70d365990 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 19 Jan 2024 17:05:08 +0800
Subject: [PATCH] Funasr1.0 (#1275)

---
 funasr/models/fsmn_vad_streaming/model.py                   |    2 --
 funasr/train_utils/trainer.py                               |   20 ++++++++++++++++----
 funasr/bin/train.py                                         |    2 +-
 funasr/auto/auto_model.py                                   |    2 +-
 examples/industrial_data_pretraining/paraformer/finetune.sh |    4 ++--
 funasr/datasets/audio_datasets/samplers.py                  |    2 +-
 6 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 7d89876..1aff068 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -11,9 +11,9 @@
 +model_revision="v2.0.2" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 +valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
 ++dataset_conf.batch_type="example" \
 ++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu" \
 +debug="true"
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index bedc17d..3320136 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -132,7 +132,7 @@
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs["model_path"]
+        self.model_path = kwargs.get("model_path", "./")
   
         
     def build_model(self, **kwargs):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0334006..d9d4d62 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -40,7 +40,7 @@
 
 
 def main(**kwargs):
-    
+    print(kwargs)
     # set random seed
     tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index e170c68..0d93098 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -28,7 +28,7 @@
         self.shuffle = shuffle and is_training
     
     def __len__(self):
-        return self.total_samples
+        return (self.total_samples-1) // self.batch_size + 1
     
     def set_epoch(self, epoch):
         np.random.seed(epoch)
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 7c21561..becfd56 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -255,7 +255,6 @@
 		self.waveform = None
 		self.last_drop_frames = 0
 
-
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
 	"""
@@ -500,7 +499,6 @@
 		#     # reset class variables and clear the dict for the next query
 		#     self.AllResetDetection()
 		return segments
-	
 
 	def init_cache(self, cache: dict = {}, **kwargs):
     
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 91b30b0..62d6be8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -147,9 +147,17 @@
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             
             self._train_epoch(epoch)
+
             
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+                
             self._validate_epoch(epoch)
-            
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+                
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
             
@@ -164,7 +172,9 @@
             
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+      
+        if self.writer:
+            self.writer.close()
         
     
     def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@
                         continue
                 
                 # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                 self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch}/{self.max_epoch}, "
+                    f"Train epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "

--
Gitblit v1.9.1