From 80bd14e6bbb7bb282ff3832194648dc4a16157ca Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 10:41:14 +0800
Subject: [PATCH] Dev gzf exp (#1657)
---
funasr/datasets/audio_datasets/espnet_samplers.py | 25 +++++++++---
funasr/train_utils/trainer.py | 9 ++++
funasr/bin/train.py | 5 ++
funasr/models/sense_voice/template.yaml | 42 +++++++++++++++-----
funasr/datasets/audio_datasets/index_ds.py | 5 ++
5 files changed, 68 insertions(+), 18 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 05942cd..eb1611a 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -32,6 +32,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
@@ -107,6 +108,7 @@
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
+ logging.info(f"model info: {model_summary(model)}")
if use_ddp:
model = model.cuda(local_rank)
model = DDP(
@@ -209,6 +211,9 @@
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
)
+
+ torch.cuda.empty_cache()
+
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index e6efe0a..3f14d09 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -92,14 +92,25 @@
max_len_in_batch = 0 # Tracks the max sample length within the current batch
for idx in sorted_indices:
- original_sample_length = self.dataset.get_source_len(idx)
- if (
- original_sample_length < self.min_token_length
- or original_sample_length > self.max_token_length
- ): # Skip samples that exceed the max length
- continue
+
+ # original_sample_length = self.dataset.get_source_len(idx)
+ # if (
+ # original_sample_length < self.min_token_length
+ # or original_sample_length > self.max_token_length
+ # ): # Skip samples that exceed the max length
+ # continue
+
+ # sample_length = 1 if self.batch_type == "example" else original_sample_length
+
# Set sample_length based on the batch type
- sample_length = 1 if self.batch_type == "example" else original_sample_length
+ if self.batch_type == "example":
+ sample_length = 1
+ elif self.batch_type == "token":
+ sample_length = self.dataset.get_source_len(idx) + int(
+ self.dataset.get_target_len(idx) * 1.2
+ )
+ else:
+ sample_length = self.dataset.get_source_len(idx)
# Calculate potential batch size with the new sample
potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
# Add index to batch if it doesn't exceed batch size limit
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index d26124b..da008b4 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -21,6 +21,7 @@
self.min_source_length = kwargs.get("min_source_length", 0)
self.max_target_length = kwargs.get("max_target_length", 2048)
self.min_target_length = kwargs.get("min_target_length", 0)
+ self.max_token_length = kwargs.get("max_token_length", 2200)
is_training = kwargs.get("is_training", True)
if not (path.endswith(".jsonl") or path.endswith(".json")):
@@ -103,6 +104,10 @@
or target_len > self.max_target_length
):
continue
+
+ if (source_len + target_len) > self.max_token_length:
+ continue
+
contents_i = {
"source": source,
"prompt": prompt,
diff --git a/funasr/models/sense_voice/template.yaml b/funasr/models/sense_voice/template.yaml
index 4699c94..1a25ea4 100644
--- a/funasr/models/sense_voice/template.yaml
+++ b/funasr/models/sense_voice/template.yaml
@@ -1,11 +1,5 @@
-# This is an example that demonstrates how to configure a model file.
-# You can modify the configuration according to your own requirements.
-
-# to print the register_table:
-# from funasr.register import tables
-# tables.print()
# network architecture
-model: SenseVoice
+model: SenseVoiceRWKV
model_conf:
lsm_weight: 0.1
length_normalized_loss: true
@@ -27,6 +21,26 @@
n_text_head: 20
n_text_layer: 32
+
+# decoder
+decoder: SenseVoiceDecoder
+decoder_conf:
+ rwkv_cfg:
+ n_embd: 1280
+ dropout: 0
+ head_size_a: 64
+ ctx_len: 1280
+ dim_att: 1280 #${model_conf.rwkv_cfg.n_embd}
+ dim_ffn: null
+ head_size_divisor: 8
+ n_layer: 32
+ pre_ffn: 0
+ ln0: false
+ ln1: false
+ init_rwkv: false
+ datatype: bf16
+
+
# frontend related
frontend: WhisperFrontend
frontend_conf:
@@ -42,12 +56,17 @@
dataset: SenseVoiceDataset
dataset_conf:
- index_ds: IndexDSJsonl
+ index_ds: IndexDSJsonlRankSplit
batch_sampler: EspnetStyleBatchSampler
- batch_type: length # example or length
- batch_size: 7000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
- max_token_length: 2000 # filter samples if source_token_len+target_token_len > max_token_length,
+ rank_split: true
+ batch_type: token # example or length
+ batch_size: 3500 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2200
min_token_length: 60
+ max_source_length: 2000
+ min_source_length: 60
+ max_target_length: 150
+ min_target_length: 0
shuffle: True
num_workers: 4
sos: ${model_conf.sos}
@@ -60,6 +79,7 @@
keep_nbest_models: 20
avg_nbest_model: ${train_conf.keep_nbest_models}
log_interval: 50
+ reset_gpu_cache: true
optim: adamw
optim_conf:
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 713e171..8f20ba4 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -107,6 +107,9 @@
self.best_step_or_epoch = ""
self.val_acc_step_or_eoch = {}
self.val_loss_step_or_eoch = {}
+
+ self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+
def save_checkpoint(
self,
@@ -324,6 +327,12 @@
time2 = time.perf_counter()
with maybe_autocast(self.use_fp16):
retval = model(**batch)
+
+ if (
+ self.reset_gpu_cache
+ and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+ ):
+ torch.cuda.empty_cache()
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
--
Gitblit v1.9.1