From 824377d2aae11dc9ebbde871e3b23a0e0cadc7af Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 四月 2024 16:59:29 +0800
Subject: [PATCH] Dev gzf exp (#1626)

---
 funasr/datasets/audio_datasets/espnet_samplers.py |    3 +
 funasr/train_utils/trainer.py                     |   13 ++++--
 funasr/bin/train.py                               |    2 +
 funasr/models/sense_voice/template.yaml           |   91 +++++++++++++++++++++++++++++++++------------
 funasr/datasets/audio_datasets/index_ds.py        |   10 ++++-
 5 files changed, 88 insertions(+), 31 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 353ce68..4ab2d8a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -55,6 +55,8 @@
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # Enable TF32 for CUDA matmuls (faster on Ampere+ GPUs at slightly reduced precision)
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
     
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     if local_rank == 0:
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index c6ed919..bca0753 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -61,6 +61,7 @@
         self.epoch = 0
         self.sort_size = sort_size * num_replicas
         self.max_token_length = kwargs.get("max_token_length", 2048)
+        self.min_token_length = kwargs.get("min_token_length", 0)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
 
 
@@ -85,7 +86,7 @@
         
         for idx in sorted_indices:
             original_sample_length = self.dataset.get_source_len(idx)
-            if original_sample_length > self.max_token_length:  # Skip samples that exceed the max length
+            if original_sample_length < self.min_token_length or original_sample_length > self.max_token_length:  # Skip samples outside the [min_token_length, max_token_length] range
                 continue
             # Set sample_length based on the batch type
             sample_length = 1 if self.batch_type == "example" else original_sample_length
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 5396c8a..53419e8 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -76,7 +76,10 @@
     
     def __init__(self, path: str, **kwargs):
         super().__init__()
-        
+        self.max_source_length = kwargs.get("max_source_length", 2048)
+        self.min_source_length = kwargs.get("min_source_length", 0)
+        self.max_target_length = kwargs.get("max_target_length", 2048)
+        self.min_target_length = kwargs.get("min_target_length", 0)
         if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
             from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
             jsonl_outdir = os.path.dirname(path[0])
@@ -101,7 +104,10 @@
                     target_len = data.get("target_len", 0)
                     if "aishell" in source:
                         target = target.replace(" ", "")
-
+                    if source_len < self.min_source_length or source_len > self.max_source_length:
+                        continue
+                    if target_len < self.min_target_length or target_len > self.max_target_length:
+                        continue
                     contents_i = {"source": source,
                                  "prompt": prompt,
                                  "target": target,
diff --git a/funasr/models/sense_voice/template.yaml b/funasr/models/sense_voice/template.yaml
index 5ebace9..4699c94 100644
--- a/funasr/models/sense_voice/template.yaml
+++ b/funasr/models/sense_voice/template.yaml
@@ -4,42 +4,85 @@
 # to print the register_table:
 # from funasr.register import tables
 # tables.print()
-
 # network architecture
 model: SenseVoice
 model_conf:
     lsm_weight: 0.1
     length_normalized_loss: true
-    hub: funasr
+    activation_checkpoint: true
+    sos: "<|startoftranscript|>"
+    eos: "<|endoftext|>"
+    downsample_rate: 4
+    use_padmask: true
 
-
-
-# only use for hub == funasr,
-#  if hub == openai, dims is automaticall download
-dims:
-    n_mels: 128
-    n_vocab: 51866
-    n_audio_ctx: 1500
-    n_audio_state: 1280
-    n_audio_head: 20
-    n_audio_layer: 32
-    n_text_ctx: 448
-    n_text_state: 1280
-    n_text_head: 20
-    n_text_layer: 32
+    dims:
+        n_mels: 128
+        n_vocab: 60515
+        n_audio_ctx: 1500
+        n_audio_state: 1280
+        n_audio_head: 20
+        n_audio_layer: 32
+        n_text_ctx: 448
+        n_text_state: 1280
+        n_text_head: 20
+        n_text_layer: 32
 
 # frontend related
 frontend: WhisperFrontend
 frontend_conf:
     fs: 16000
-    n_mels: ${dims.n_mels}
-    do_pad_trim: true
+    n_mels: ${model_conf.dims.n_mels}
+    do_pad_trim: false
 
-tokenizer: WhisperTokenizer
+tokenizer: SenseVoiceTokenizer
 tokenizer_conf:
-  language: null
-  task: transcribe
+  vocab_path: null
   is_multilingual: true
-  num_languages: 100
+  num_languages: 8749
 
-scope_map: [none, "model."]
\ No newline at end of file
+dataset: SenseVoiceDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 7000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2000 # filter samples if source_token_len+target_token_len > max_token_length,
+    min_token_length: 60
+    shuffle: True
+    num_workers: 4
+    sos: ${model_conf.sos}
+    eos: ${model_conf.eos}
+
+train_conf:
+  accum_grad: 2
+  grad_clip: 5
+  max_epoch: 20
+  keep_nbest_models: 20
+  avg_nbest_model: ${train_conf.keep_nbest_models}
+  log_interval: 50
+
+optim: adamw
+optim_conf:
+    lr: 0.00002
+
+scheduler: warmuplr
+scheduler_conf:
+    warmup_steps: 10000
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 40
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_ratio_range:
+    - 0.0
+    - 0.12
+    num_time_mask: 2
+
+scope_map: ['encoder.encoders', 'model.encoder', 'decoder.decoders', 'model.decoder']
\ No newline at end of file
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 35a266f..caaef38 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -268,10 +268,12 @@
         # Initialize the gradient accumulation
         optim.zero_grad()
         speed_stats = {}
-        time5 = time.perf_counter()
+        
         iterator_stop = torch.tensor(0).to(self.device)
 
         dataloader_train.batch_sampler.set_epoch(epoch)
+        time_beg = time.perf_counter()
+        time5 = time_beg
         for batch_idx, batch in enumerate(dataloader_train):
             if self.use_ddp or self.use_fsdp:
                 dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
@@ -279,11 +281,13 @@
                     break
             self.batch_total += 1
             time1 = time.perf_counter()
-            speed_stats["data_load"] = f"{time1-time5:0.3f}"
+            speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
 
             batch = to_device(batch, self.device)
-            
-            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+
+            my_context = nullcontext
+            if self.use_ddp or self.use_fsdp:
+                my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
             with my_context():
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
@@ -384,6 +388,7 @@
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
+            time_beg = time.perf_counter()
         else:
             if self.use_ddp or self.use_fsdp:
                 iterator_stop.fill_(1)

--
Gitblit v1.9.1