From ff4306346eae4021c711df3fe23979e82e06e751 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 19 Feb 2024 21:26:25 +0800
Subject: [PATCH] aishell example
---
funasr/frontends/wav_frontend.py | 1
funasr/bin/train.py | 5 -
examples/aishell/paraformer/run.sh | 7 +-
funasr/bin/compute_audio_cmvn.py | 23 ++++---
funasr/datasets/audio_datasets/preprocessor.py | 83 +++++++++++++++++++++++++++
funasr/datasets/audio_datasets/datasets.py | 6 +-
6 files changed, 105 insertions(+), 20 deletions(-)
diff --git a/examples/aishell/paraformer/run.sh b/examples/aishell/paraformer/run.sh
index 410751a..149f4d7 100755
--- a/examples/aishell/paraformer/run.sh
+++ b/examples/aishell/paraformer/run.sh
@@ -50,6 +50,7 @@
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
+ mkdir -p ${raw_data}
local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
@@ -76,9 +77,8 @@
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "stage 1: Feature and CMVN Generation"
-# utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0
python ../../../funasr/bin/compute_audio_cmvn.py \
- --config-path "${workspace}" \
+ --config-path "${workspace}/conf" \
--config-name "${config}" \
++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
@@ -109,13 +109,14 @@
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "stage 4: ASR Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
echo "log_file: ${log_file}"
torchrun \
--nnodes 1 \
--nproc_per_node ${gpu_num} \
../../../funasr/bin/train.py \
- --config-path "${workspace}" \
+ --config-path "${workspace}/conf" \
--config-name "${config}" \
++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
++tokenizer_conf.token_list="${token_list}" \
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index b66bb14..4561bec 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -79,8 +79,8 @@
fbank = batch["speech"].numpy()[0, :, :]
if total_frames == 0:
- mean_stats = fbank
- var_stats = np.square(fbank)
+ mean_stats = np.sum(fbank, axis=0)
+ var_stats = np.sum(np.square(fbank), axis=0)
else:
mean_stats += np.sum(fbank, axis=0)
var_stats += np.sum(np.square(fbank), axis=0)
@@ -110,14 +111,14 @@
fout.write("</Nnet>" + '\n')
-
+
+"""
+python funasr/bin/compute_audio_cmvn.py \
+--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
+--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
+++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
+++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
+++dataset_conf.num_workers=0
+"""
if __name__ == "__main__":
main_hydra()
- """
- python funasr/bin/compute_status.py \
- --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
- --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
- ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
- ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
- ++dataset_conf.num_workers=32
- """
\ No newline at end of file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index c9a4a67..d916509 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -79,9 +79,8 @@
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
-
- # import pdb;
- # pdb.set_trace()
+
+
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 62acb44..ab08fb0 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -22,12 +22,12 @@
self.index_ds = index_ds_class(path, **kwargs)
preprocessor_speech = kwargs.get("preprocessor_speech", None)
if preprocessor_speech:
- preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
self.preprocessor_speech = preprocessor_speech
preprocessor_text = kwargs.get("preprocessor_text", None)
if preprocessor_text:
- preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
self.preprocessor_text = preprocessor_text
@@ -57,7 +57,7 @@
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
- data_src = self.preprocessor_speech(data_src)
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
target = item["target"]
diff --git a/funasr/datasets/audio_datasets/preprocessor.py b/funasr/datasets/audio_datasets/preprocessor.py
new file mode 100644
index 0000000..6c21fbf
--- /dev/null
+++ b/funasr/datasets/audio_datasets/preprocessor.py
@@ -0,0 +1,83 @@
+import os
+import json
+import torch
+import logging
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from typing import Collection
+import torch
+import torchaudio
+from torch import nn
+import random
+import re
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.register import tables
+
+
+@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
+class SpeechPreprocessSpeedPerturb(nn.Module):
+ def __init__(self, speed_perturb: list=None, **kwargs):
+ super().__init__()
+ self.speed_perturb = speed_perturb
+
+ def forward(self, waveform, fs, **kwargs):
+ if self.speed_perturb is None:
+ return waveform
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
+ waveform = waveform.view(-1)
+
+ return waveform
+
+
+@tables.register("preprocessor_classes", "TextPreprocessSegDict")
+class TextPreprocessSegDict(nn.Module):
+ def __init__(self, seg_dict: str = None,
+ text_cleaner: Collection[str] = None,
+ split_with_space: bool = False,
+ **kwargs):
+ super().__init__()
+
+ self.seg_dict = None
+ if seg_dict is not None:
+ self.seg_dict = {}
+ with open(seg_dict, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ self.seg_dict[key] = " ".join(value)
+ self.text_cleaner = TextCleaner(text_cleaner)
+ self.split_with_space = split_with_space
+
+ def forward(self, text, **kwargs):
+ if self.seg_dict is not None:
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ text = seg_tokenize(tokens, self.seg_dict)
+
+ return text
+
+def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
\ No newline at end of file
--
Gitblit v1.9.1