游雁
2024-02-19 ff4306346eae4021c711df3fe23979e82e06e751
aishell example
5个文件已修改
1个文件已添加
125 ■■■■ 已修改文件
examples/aishell/paraformer/run.sh 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/compute_audio_cmvn.py 23 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/datasets.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/preprocessor.py 83 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/frontends/wav_frontend.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/paraformer/run.sh
@@ -50,6 +50,7 @@
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    mkdir -p ${raw_data}
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
@@ -76,9 +77,8 @@
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
#    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$config" --scale 1.0
    python ../../../funasr/bin/compute_audio_cmvn.py \
    --config-path "${workspace}" \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json" \
@@ -109,13 +109,14 @@
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "stage 4: ASR Training"
  mkdir -p ${exp_dir}/exp/${model_dir}
  log_file="${exp_dir}/exp/${model_dir}/train.log.txt"
  echo "log_file: ${log_file}"
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}" \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
  ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
  ++tokenizer_conf.token_list="${token_list}" \
funasr/bin/compute_audio_cmvn.py
@@ -79,8 +79,8 @@
        fbank = batch["speech"].numpy()[0, :, :]
        if total_frames == 0:
            mean_stats = fbank
            var_stats = np.square(fbank)
            mean_stats = np.sum(fbank, axis=0)
            var_stats = np.sum(np.square(fbank), axis=0)
        else:
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
@@ -93,6 +93,7 @@
        'total_frames': total_frames
    }
    cmvn_file = kwargs.get("cmvn_file", "cmvn.json")
    # import pdb;pdb.set_trace()
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))
        
@@ -110,14 +111,14 @@
        fout.write("</Nnet>" + '\n')
    
    
"""
python funasr/bin/compute_audio_cmvn.py \
--config-path "/Users/zhifu/funasr1.0/examples/aishell/paraformer/conf" \
--config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
++dataset_conf.num_workers=0
"""
if __name__ == "__main__":
    main_hydra()
    """
    python funasr/bin/compute_status.py \
    --config-path "/Users/zhifu/funasr1.0/examples/aishell/conf" \
    --config-name "train_asr_paraformer_conformer_12e_6d_2048_256.yaml" \
    ++train_data_set_list="/Users/zhifu/funasr1.0/data/list/audio_datasets.jsonl" \
    ++cmvn_file="/Users/zhifu/funasr1.0/data/list/cmvn.json" \
    ++dataset_conf.num_workers=32
    """
funasr/bin/train.py
@@ -79,9 +79,8 @@
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # import pdb;
    # pdb.set_trace()
    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
funasr/datasets/audio_datasets/datasets.py
@@ -22,12 +22,12 @@
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
        
@@ -57,7 +57,7 @@
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src)
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
        target = item["target"]
funasr/datasets/audio_datasets/preprocessor.py
New file
@@ -0,0 +1,83 @@
import os
import json
import torch
import logging
import concurrent.futures
import librosa
import torch.distributed as dist
from typing import Collection
import torch
import torchaudio
from torch import nn
import random
import re
from funasr.tokenizer.cleaner import TextCleaner
from funasr.register import tables
@tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb")
class SpeechPreprocessSpeedPerturb(nn.Module):
    """Randomly speed-perturb a waveform using sox effects.

    On every call one factor is drawn uniformly from ``speed_perturb``. When it
    differs from 1.0, the waveform is run through sox's ``speed`` effect and
    then resampled back to ``fs`` with the ``rate`` effect.

    Args:
        speed_perturb: candidate speed factors. ``None`` or an empty list
            disables perturbation and the input is returned untouched.
    """

    def __init__(self, speed_perturb: list = None, **kwargs):
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Return ``waveform``, possibly speed-perturbed.

        Args:
            waveform: audio samples (tensor or array-like). It is flattened to a
                single channel before perturbation — NOTE(review): assumes mono
                input; confirm callers never pass multi-channel audio.
            fs: sample rate in Hz; the trailing ``rate`` effect restores it.

        Returns:
            The input object unchanged when no perturbation is applied,
            otherwise a 1-D ``torch.Tensor``.
        """
        # `not` also covers an empty list, which previously made
        # random.choice raise IndexError.
        if not self.speed_perturb:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            # as_tensor avoids a copy (and torch's copy-construct warning) when
            # the input is already a tensor; reshape also accepts
            # non-contiguous input where view would fail.
            wav = torch.as_tensor(waveform).reshape(1, -1)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                wav, fs, [['speed', str(speed)], ['rate', str(fs)]])
            waveform = waveform.reshape(-1)
        return waveform
@tables.register("preprocessor_classes", "TextPreprocessSegDict")
class TextPreprocessSegDict(nn.Module):
    """Clean text and optionally map whitespace-separated words via a seg dict.

    Args:
        seg_dict: path to a UTF-8 file with one entry per line of the form
            ``<word> <unit1> <unit2> ...``. Blank lines are ignored.
        text_cleaner: cleaner specification forwarded to ``TextCleaner``.
        split_with_space: when True (and ``seg_dict`` is set), the cleaned text
            is split on whitespace and converted with ``seg_tokenize``.
    """

    def __init__(self, seg_dict: str = None,
                 text_cleaner: Collection[str] = None,
                 split_with_space: bool = False,
                 **kwargs):
        super().__init__()
        self.seg_dict = None
        if seg_dict is not None:
            self.seg_dict = {}
            with open(seg_dict, "r", encoding="utf8") as f:
                for line in f:
                    fields = line.strip().split()
                    # Blank lines (e.g. a trailing newline) used to raise
                    # IndexError on fields[0]; skip them.
                    if not fields:
                        continue
                    self.seg_dict[fields[0]] = " ".join(fields[1:])
        self.text_cleaner = TextCleaner(text_cleaner)
        self.split_with_space = split_with_space

    def forward(self, text, **kwargs):
        """Preprocess one transcript.

        Returns:
            ``text`` unchanged when no seg dict is configured; otherwise the
            cleaned text, which is a ``str`` when ``split_with_space`` is False
            and the token list from ``seg_tokenize`` when it is True.
            NOTE(review): the return type differs between the two modes, and the
            cleaner only runs when a seg dict is configured — confirm both are
            intended by the caller.
        """
        if self.seg_dict is None:
            return text
        text = self.text_cleaner(text)
        if self.split_with_space:
            # split() rather than strip().split(" "): runs of whitespace must
            # not yield empty words, which seg_tokenize would map to "<unk>".
            text = seg_tokenize(text.split(), self.seg_dict)
        return text
# Words made only of CJK unified ideographs (U+4E00..U+9FA5) and ASCII digits
# are spelled out character by character when missing from the seg dict.
# Compiled once at import instead of on every call.
_SEG_CHAR_PATTERN = re.compile(r'^[\u4E00-\u9FA50-9]+$')


def seg_tokenize(txt, seg_dict):
    """Convert a sequence of words into seg-dict units.

    Args:
        txt: iterable of already-split words; each is lower-cased before lookup.
        seg_dict: mapping from a word or character to a space-separated string
            of units.

    Returns:
        A flat list of unit tokens. A word present in ``seg_dict`` expands to
        its units. An unknown word made only of CJK ideographs/digits is looked
        up character by character, and each missing character becomes
        ``"<unk>"``. Any other unknown word becomes a single ``"<unk>"``.
    """
    # Collect tokens in a list instead of growing a string with += and
    # re-splitting it; the output is identical, without the quadratic concat.
    tokens = []
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            tokens.extend(seg_dict[word].split())
        elif _SEG_CHAR_PATTERN.match(word):
            for char in word:
                if char in seg_dict:
                    tokens.extend(seg_dict[char].split())
                else:
                    tokens.append("<unk>")
        else:
            tokens.append("<unk>")
    return tokens
funasr/frontends/wav_frontend.py
@@ -32,6 +32,7 @@
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    import pdb;pdb.set_trace()
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])