From 11cf10e433c173efd892766b669e0bba57253fed Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 14:52:20 +0800
Subject: [PATCH] Dev gzf exp (#1678)
---
funasr/models/sense_voice/model.py | 2 ++
funasr/schedulers/lambdalr_cus.py | 42 +++++++++++++++++++++++++-----------------
funasr/tokenizer/abs_tokenizer.py | 2 +-
funasr/datasets/audio_datasets/scp2jsonl.py | 13 ++++++++++---
4 files changed, 38 insertions(+), 21 deletions(-)
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f6ceb69..f167173 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@
import concurrent.futures
import librosa
import torch.distributed as dist
+from tqdm import tqdm
def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@
with open(data_file, "r") as f:
data_file_lists = f.readlines()
+ print("")
lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
# import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@
i * lines_for_each_th : (i + 1) * lines_for_each_th
],
data_type,
+ i,
)
for i in range(task_num)
]
@@ -69,11 +72,15 @@
dist.barrier()
-def parse_context_length(data_list: list, data_type: str):
-
+def parse_context_length(data_list: list, data_type: str, id=0):
+ pbar = tqdm(total=len(data_list), dynamic_ncols=True)
res = {}
for i, line in enumerate(data_list):
- key, line = line.strip().split(maxsplit=1)
+ pbar.update(1)
+ pbar.set_description(f"cpu: {id}")
+ lines = line.strip().split(maxsplit=1)
+ key = lines[0]
+ line = lines[1] if len(lines) > 1 else ""
line = line.strip()
if os.path.exists(line):
waveform, _ = librosa.load(line, sr=16000)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 07fb4eb..ae20902 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -329,6 +329,8 @@
stats["loss"] = torch.clone(loss.detach())
stats["batch_size"] = batch_size
stats["batch_size_x_frames"] = frames * batch_size
+ stats["batch_size_real_frames"] = speech_lengths.sum().item()
+ stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 19ad7a8..e3bb1fb 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -2,28 +2,36 @@
from torch.optim.lr_scheduler import _LRScheduler
+# class CustomLambdaLR(_LRScheduler):
+# def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+# self.warmup_steps = warmup_steps
+# super().__init__(optimizer, last_epoch)
+#
+# def get_lr(self):
+# if self.last_epoch < self.warmup_steps:
+# return [
+# base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
+# ]
+# else:
+# return [base_lr for base_lr in self.base_lrs]
+
+
class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+ def __init__(
+ self,
+ optimizer,
+ warmup_steps: int = 25000,
+ total_steps: int = 500000,
+ last_epoch=-1,
+ verbose=False,
+ ):
self.warmup_steps = warmup_steps
- super().__init__(optimizer, last_epoch)
+ self.total_steps = total_steps
+ super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
- if self.last_epoch < self.warmup_steps:
- return [
- base_lr * min(self.last_epoch / self.warmup_steps, 1) for base_lr in self.base_lrs
- ]
- else:
- return [base_lr for base_lr in self.base_lrs]
-
-class CustomLambdaLR(_LRScheduler):
- def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
- self.warmup_steps = train_config.warmup_steps
- self.total_steps = train_config.total_steps
- super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
-
- def get_lr(self):
- step = self._step_count
+ step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index a629e94..e125d29 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -62,7 +62,7 @@
raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
self.unk_id = self.token2id[self.unk_symbol]
- def encode(self, text):
+ def encode(self, text, **kwargs):
tokens = self.text2tokens(text)
text_ints = self.tokens2ids(tokens)
--
Gitblit v1.9.1