Merge pull request #453 from alibaba-damo-academy/dev_clas
Update NeatContextualParaformer, finetune pipeline and dataset
| | |
| | | ../TEMPLATE/README.md |
| | | ../../TEMPLATE/README.md |
| New file |
| | |
| | | import os |
| | | |
| | | from modelscope.metainfo import Trainers |
| | | from modelscope.trainers import build_trainer |
| | | |
| | | import funasr |
| | | from funasr.datasets.ms_dataset import MsDataset |
| | | from funasr.utils.modelscope_param import modelscope_args |
| | | |
| | | |
def modelscope_finetune(params):
    """Fine-tune an ASR model with the ModelScope speech ASR trainer.

    Args:
        params: Settings object; the attributes ``output_dir``, ``data_path``,
            ``model``, ``dataset_type``, ``batch_bins``, ``max_epoch`` and
            ``lr`` are read from it.
    """
    work_dir = params.output_dir
    if not os.path.exists(work_dir):
        os.makedirs(work_dir, exist_ok=True)

    # dataset split ["train", "validation"]
    datasets = MsDataset.load(params.data_path)

    trainer_args = {
        "model": params.model,
        "data_dir": datasets,
        "dataset_type": params.dataset_type,
        "work_dir": work_dir,
        "batch_bins": params.batch_bins,
        "max_epoch": params.max_epoch,
        "lr": params.lr,
    }
    build_trainer(Trainers.speech_asr_trainer, default_args=trainer_args).train()
| | | |
| | | |
if __name__ == '__main__':
    model_id = "damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
    params = modelscope_args(model=model_id, data_path="./data")

    # Where the fine-tuned model is saved.
    params.output_dir = "./checkpoint"
    # Location of the training data.
    params.data_path = "./example_data/"
    # Use "small" for small datasets; use "large" when the data exceeds 1000 hours.
    params.dataset_type = "large"
    # Batch size: counted in fbank feature frames when dataset_type="small",
    # and in milliseconds when dataset_type="large".
    params.batch_bins = 2000
    # Maximum number of training epochs.
    params.max_epoch = 50
    # Learning rate.
    params.lr = 0.00005

    modelscope_finetune(params)
| New file |
| | |
#!/usr/bin/env bash
# Decode ${data_dir}/wav.scp with a Paraformer model and score the result.
#   always : split wav.scp into ${nj} parts; optionally prepare a fine-tuned checkpoint
#   stage 1: decode the parts in parallel (one background job per part) and merge outputs
#   stage 2: score with utils/compute_wer.py
#   stage 3: SpeechIO TIOBE text normalisation + scoring (skipped with the default stop_stage=2)

set -e
set -u
set -o pipefail

stage=1
stop_stage=2
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
data_dir="./data/test"
output_dir="./results"
batch_size=64
gpu_inference=true  # whether to perform gpu decoding
gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
njob=10             # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"
hotword_txt=None

# Defaults above are presumably overridable from the command line (Kaldi-style parse_options).
. utils/parse_options.sh || exit 1;

# Compare as a string; the option value must never be executed as a command.
if [ "${gpu_inference}" == "true" ]; then
    # One job per listed GPU.
    nj=$(echo "${gpuid_list}" | awk -F "," '{print NF}')
else
    # CPU decoding: ${njob} jobs, each given gpuid -1 and batch size 1.
    nj=$njob
    batch_size=1
    gpuid_list=""
    for JOB in $(seq "${nj}"); do
        gpuid_list=$gpuid_list"-1,"
    done
fi

mkdir -p "${output_dir}/split"
split_scps=""
for JOB in $(seq "${nj}"); do
    split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
done
# ${split_scps} is intentionally unquoted: it must expand to one argument per split file.
perl utils/split_scp.pl "${data_dir}/wav.scp" ${split_scps}

if [ -n "${checkpoint_dir}" ]; then
    python utils/prepare_checkpoint.py "${model}" "${checkpoint_dir}" "${checkpoint_name}"
    model="${checkpoint_dir}/${model}"
fi

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    echo "Decoding ..."
    # Intentional word splitting: "0,1" -> (0 1)
    gpuid_list_array=(${gpuid_list//,/ })
    for JOB in $(seq "${nj}"); do
        {
            id=$((JOB-1))
            gpuid=${gpuid_list_array[$id]}
            mkdir -p "${output_dir}/output.$JOB"
            python infer.py \
                --model "${model}" \
                --audio_in "${output_dir}/split/wav.$JOB.scp" \
                --output_dir "${output_dir}/output.$JOB" \
                --batch_size "${batch_size}" \
                --hotword_txt "${hotword_txt}" \
                --gpuid "${gpuid}"
        }&
    done
    # NOTE: a bare `wait` returns 0 even when a background job failed, so a
    # failed decoding job is not detected here.
    wait

    # Concatenate the per-job result files and sort them by utterance key.
    mkdir -p "${output_dir}/1best_recog"
    for f in token score text; do
        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
            for i in $(seq "${nj}"); do
                cat "${output_dir}/output.${i}/1best_recog/${f}"
            done | sort -k1 >"${output_dir}/1best_recog/${f}"
        fi
    done
fi

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    echo "Computing WER ..."
    cp "${output_dir}/1best_recog/text" "${output_dir}/1best_recog/text.proc"
    cp "${data_dir}/text" "${output_dir}/1best_recog/text.ref"
    python utils/compute_wer.py "${output_dir}/1best_recog/text.ref" "${output_dir}/1best_recog/text.proc" "${output_dir}/1best_recog/text.cer"
    tail -n 3 "${output_dir}/1best_recog/text.cer"
fi

# Stage 3 reads text.proc, which is written by stage 2.
if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    echo "SpeechIO TIOBE textnorm"
    echo "$0 --> Normalizing REF text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${data_dir}/text" \
        "${output_dir}/1best_recog/ref.txt"

    echo "$0 --> Normalizing HYP text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${output_dir}/1best_recog/text.proc" \
        "${output_dir}/1best_recog/rec.txt"
    # Drop hypotheses that are empty after normalisation (line ends with a tab).
    grep -v $'\t$' "${output_dir}/1best_recog/rec.txt" > "${output_dir}/1best_recog/rec_non_empty.txt"

    echo "$0 --> computing WER/CER and alignment ..."
    ./utils/error_rate_zh \
        --tokenizer char \
        --ref "${output_dir}/1best_recog/ref.txt" \
        --hyp "${output_dir}/1best_recog/rec_non_empty.txt" \
        "${output_dir}/1best_recog/DETAILS.txt" | tee "${output_dir}/1best_recog/RESULTS.txt"
    rm -rf "${output_dir}/1best_recog/rec.txt" "${output_dir}/1best_recog/rec_non_empty.txt"
fi
| | | |
| New file |
| | |
| | | import os |
| | | import tempfile |
| | | import codecs |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | from modelscope.msdatasets import MsDataset |
| | | |
if __name__ == '__main__':
    # Demo: write a wav.scp for the first 50 items of a hot-word test set and
    # run the contextual Paraformer recognition pipeline on it.
    param_dict = dict()
    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"

    output_dir = "./output"
    batch_size = 1

    # dataset split ['test']
    ds_dict = MsDataset.load(dataset_name='speech_asr_aishell1_hotwords_testsets', namespace='speech_asr')

    # mkdtemp() creates the directory and leaves it in place.
    # TemporaryDirectory().name only yields the path of a directory that is
    # removed again as soon as the temporary object is garbage-collected.
    work_dir = tempfile.mkdtemp()
    wav_file_path = os.path.join(work_dir, "wav.scp")

    max_utterances = 50
    with open(wav_file_path, 'w') as scp_file:
        for counter, line in enumerate(ds_dict, start=1):
            wav = line["Audio:FILE"]
            # Utterance id = file name up to its first dot.
            idx = wav.split("/")[-1].split(".")[0]
            scp_file.write(idx + " " + wav + "\n")
            if counter == max_utterances:
                break
    audio_in = wav_file_path

    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
        output_dir=output_dir,
        batch_size=batch_size,
        param_dict=param_dict)

    rec_result = inference_pipeline(audio_in=audio_in)
| | |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.bin.tp_inference import SpeechText2Timestamp |
| | |
| | | pre_token_length = pre_token_length.round().long() |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | if not isinstance(self.asr_model, ContextualParaformer): |
| | | if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer): |
| | | if self.hotword_list: |
| | | logging.warning("Hotword is given but asr model is not a ContextualParaformer.") |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | |
| | | finetune_configs = yaml.safe_load(f) |
| | | # set data_types |
| | | if dataset_type == "large": |
| | | finetune_configs["dataset_conf"]["data_types"] = "sound,text" |
| | | if 'data_types' not in finetune_configs['dataset_conf']: |
| | | finetune_configs["dataset_conf"]["data_types"] = "sound,text" |
| | | finetune_configs = update_dct(configs, finetune_configs) |
| | | for key, value in finetune_configs.items(): |
| | | if hasattr(args, key): |
| | |
| | | if args.dataset_type == "small": |
| | | args.batch_bins = batch_bins |
| | | elif args.dataset_type == "large": |
| | | args.dataset_conf["batch_conf"]["batch_size"] = batch_bins |
| | | if "batch_size" not in args.dataset_conf["batch_conf"]: |
| | | args.dataset_conf["batch_conf"]["batch_size"] = batch_bins |
| | | else: |
| | | raise ValueError(f"Not supported dataset_type={args.dataset_type}") |
| | | if args.normalize in ["null", "none", "None"]: |
| | |
| | | if data_type == "kaldi_ark": |
| | | ark_reader = ReadHelper('ark:{}'.format(data_file)) |
| | | reader_list.append(ark_reader) |
| | | elif data_type == "text" or data_type == "sound": |
| | | elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword': |
| | | text_reader = open(data_file, "r") |
| | | reader_list.append(text_reader) |
| | | elif data_type == "none": |
| | |
| | | sample_dict["sampling_rate"] = sampling_rate |
| | | if data_name == "speech": |
| | | sample_dict["key"] = key |
| | | elif data_type == "text_hotword": |
| | | text = item |
| | | segs = text.strip().split() |
| | | sample_dict[data_name] = segs[1:] |
| | | if "key" not in sample_dict: |
| | | sample_dict["key"] = segs[0] |
| | | sample_dict['hw_tag'] = 1 |
| | | else: |
| | | text = item |
| | | segs = text.strip().split() |
| | |
| | | shuffle = conf.get('shuffle', True) |
| | | data_names = conf.get("data_names", "speech,text") |
| | | data_types = conf.get("data_types", "kaldi_ark,text") |
| | | dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode) |
| | | |
| | | pre_hwfile = conf.get("pre_hwlist", None) |
| | | pre_prob = conf.get("pre_prob", 0) # unused yet |
| | | |
| | | hw_config = {"sample_rate": conf.get("sample_rate", 0.6), |
| | | "double_rate": conf.get("double_rate", 0.1), |
| | | "hotword_min_length": conf.get("hotword_min_length", 2), |
| | | "hotword_max_length": conf.get("hotword_max_length", 8), |
| | | "pre_prob": conf.get("pre_prob", 0.0)} |
| | | |
| | | if pre_hwfile is not None: |
| | | pre_hwlist = [] |
| | | with open(pre_hwfile, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | pre_hwlist.append(line.strip()) |
| | | else: |
| | | pre_hwlist = None |
| | | |
| | | dataset = AudioDataset(scp_lists, |
| | | data_names, |
| | | data_types, |
| | | frontend_conf=frontend_conf, |
| | | shuffle=shuffle, |
| | | mode=mode, |
| | | ) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |
| | | dataset = FilterIterDataPipe(dataset, fn=filter_fn) |
| | | |
| | | if "text" in data_names: |
| | | vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer} |
| | | vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config} |
| | | tokenize_fn = partial(tokenize, **vocab) |
| | | dataset = MapperIterDataPipe(dataset, fn=tokenize_fn) |
| | | |
| New file |
| | |
| | | import random |
| | | |
def sample_hotword(length,
                   hotword_min_length,
                   hotword_max_length,
                   sample_rate,
                   double_rate,
                   pre_prob,
                   pre_index=None):
    """Randomly pick hot-word span(s) inside a sequence of ``length`` tokens.

    Args:
        length: Number of tokens in the sentence.
        hotword_min_length: Minimum span length in tokens.
        hotword_max_length: Maximum span length in tokens.
        sample_rate: Probability of sampling any hot word at all.
        double_rate: Probability of sampling two hot words; only applies when
            ``length > hotword_max_length + hotword_min_length + 2``.
        pre_prob: Probability of returning ``pre_index`` instead of sampling.
        pre_index: Optional precomputed index list, returned unchanged when
            it is selected.

    Returns:
        ``[-1]`` when nothing is sampled, ``[start, end]`` for one span, or
        ``[start1, end1, start2, end2]`` for two non-overlapping spans.
        All indices are inclusive.
    """
    if length < hotword_min_length:
        return [-1]
    if random.random() >= sample_rate:
        return [-1]

    if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
        return pre_index
    if length == hotword_min_length:
        return [0, length - 1]

    if random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
        # sample two hotwords in a sentence
        _max_hw_length = min(hotword_max_length, length // 2)
        # first hotword; both upper bounds are clamped so that a second span of
        # at least hotword_min_length tokens always fits behind it (otherwise
        # randint() below would be handed an empty range and raise ValueError)
        start1 = random.randint(0, min(length // 3, length - 2 * hotword_min_length))
        end1 = random.randint(start1 + hotword_min_length - 1,
                              min(start1 + _max_hw_length - 1, length - hotword_min_length - 1))
        # second hotword
        start2 = random.randint(end1 + 1, length - hotword_min_length)
        end2 = random.randint(min(length - 1, start2 + hotword_min_length - 1),
                              min(length - 1, start2 + hotword_max_length - 1))
        return [start1, end1, start2, end2]

    # single hotword
    start = random.randint(0, length - hotword_min_length)
    end = random.randint(min(length - 1, start + hotword_min_length - 1),
                         min(length - 1, start + hotword_max_length - 1))
    return [start, end]
| | |
| | | batch = {} |
| | | data_names = data[0].keys() |
| | | for data_name in data_names: |
| | | if data_name == "key" or data_name =="sampling_rate": |
| | | if data_name == "key" or data_name == "sampling_rate": |
| | | continue |
| | | else: |
| | | if data[0][data_name].dtype.kind == "i": |
| | | pad_value = int_pad_value |
| | | tensor_type = torch.int64 |
| | | else: |
| | | pad_value = float_pad_value |
| | | tensor_type = torch.float32 |
| | | if data_name != 'hotword_indxs': |
| | | if data[0][data_name].dtype.kind == "i": |
| | | pad_value = int_pad_value |
| | | tensor_type = torch.int64 |
| | | else: |
| | | pad_value = float_pad_value |
| | | tensor_type = torch.float32 |
| | | |
| | | tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data] |
| | | tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32) |
| | |
| | | batch[data_name] = tensor_pad |
| | | batch[data_name + "_lengths"] = tensor_lengths |
| | | |
| | | # DHA, EAHC NOT INCLUDED |
| | | if "hotword_indxs" in batch: |
| | | # if hotword indxs in batch |
| | | # use it to slice hotwords out |
| | | hotword_list = [] |
| | | hotword_lengths = [] |
| | | text = batch['text'] |
| | | text_lengths = batch['text_lengths'] |
| | | hotword_indxs = batch['hotword_indxs'] |
| | | num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2 |
| | | B, t1 = text.shape |
| | | t1 += 1 # TODO: as parameter which is same as predictor_bias |
| | | ideal_attn = torch.zeros(B, t1, num_hw+1) |
| | | nth_hw = 0 |
| | | for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)): |
| | | ideal_attn[b][:,-1] = 1 |
| | | if hotword_indx[0] != -1: |
| | | start, end = int(hotword_indx[0]), int(hotword_indx[1]) |
| | | hotword = one_text[start: end+1] |
| | | hotword_list.append(hotword) |
| | | hotword_lengths.append(end-start+1) |
| | | ideal_attn[b][start:end+1, nth_hw] = 1 |
| | | ideal_attn[b][start:end+1, -1] = 0 |
| | | nth_hw += 1 |
| | | if len(hotword_indx) == 4 and hotword_indx[2] != -1: |
| | | # the second hotword if exist |
| | | start, end = int(hotword_indx[2]), int(hotword_indx[3]) |
| | | hotword_list.append(one_text[start: end+1]) |
| | | hotword_lengths.append(end-start+1) |
| | | ideal_attn[b][start:end+1, nth_hw-1] = 1 |
| | | ideal_attn[b][start:end+1, -1] = 0 |
| | | nth_hw += 1 |
| | | hotword_list.append(torch.tensor([1])) |
| | | hotword_lengths.append(1) |
| | | hotword_pad = pad_sequence(hotword_list, |
| | | batch_first=True, |
| | | padding_value=0) |
| | | batch["hotword_pad"] = hotword_pad |
| | | batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32) |
| | | batch['ideal_attn'] = ideal_attn |
| | | del batch['hotword_indxs'] |
| | | del batch['hotword_indxs_lengths'] |
| | | |
| | | return keys, batch |
| | |
| | | #!/usr/bin/env python |
| | | import re |
| | | import numpy as np |
| | | from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword |
| | | |
| | | def forward_segment(text, seg_dict): |
| | | word_list = [] |
| | |
| | | vocab=None, |
| | | seg_dict=None, |
| | | punc_dict=None, |
| | | bpe_tokenizer=None): |
| | | bpe_tokenizer=None, |
| | | hw_config=None): |
| | | assert "text" in data |
| | | assert isinstance(vocab, dict) |
| | | text = data["text"] |
| | |
| | | text = seg_tokenize(text, seg_dict) |
| | | |
| | | length = len(text) |
| | | if 'hw_tag' in data: |
| | | hotword_indxs = sample_hotword(length, **hw_config) |
| | | data['hotword_indxs'] = hotword_indxs |
| | | del data['hw_tag'] |
| | | for i in range(length): |
| | | x = text[i] |
| | | if i == length-1 and "punc" in data and x.startswith("vad:"): |
| New file |
| | |
| | | import logging |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import numpy as np |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.modules.add_sos_eos import add_sos_eos |
| | | from funasr.modules.nets_utils import make_pad_mask, pad_list |
| | | from funasr.modules.nets_utils import th_accuracy |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.models.e2e_asr_paraformer import Paraformer |
| | | |
| | | |
# Use torch's mixed-precision context manager when this torch build ships it;
# detecting the feature directly avoids parsing version strings.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in used when ``torch.cuda.amp.autocast`` is unavailable."""
        yield
| | | |
| | | |
class NeatContextualParaformer(Paraformer):
    """Paraformer whose decoder additionally receives hot-word embeddings.

    Hot words are embedded with ``bias_embed`` (or with the decoder's own
    embedding when ``use_decoder_embedding`` is set), run through
    ``bias_encoder`` and handed to the decoder as ``contextual_info``.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        target_buffer_length: int = -1,
        inner_dim: int = 256,
        bias_encoder_type: str = 'lstm',
        use_decoder_embedding: bool = False,
        crit_attn_weight: float = 0.0,
        crit_attn_smooth: float = 0.0,
        bias_encoder_dropout_rate: float = 0.0,
    ):
        """Build the base Paraformer and add the hot-word (bias) modules.

        All arguments up to ``sampling_ratio`` are forwarded unchanged to
        ``Paraformer.__init__``. The remaining ones are:

        Args:
            target_buffer_length: When > 0, hot-word buffer attributes are
                initialised (they are not used elsewhere in this class).
            inner_dim: Size of the hot-word embedding and of the LSTM state.
            bias_encoder_type: ``'lstm'`` or ``'mean'``.
            use_decoder_embedding: Embed hot words with ``decoder.embed``
                instead of ``bias_embed``.
            crit_attn_weight: Weight applied to ``loss_ideal`` in ``forward``.
            crit_attn_smooth: Stored on the instance.
            bias_encoder_dropout_rate: ``dropout`` argument of the LSTM.

        Raises:
            ValueError: If ``bias_encoder_type`` is not supported.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )

        if bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
        elif bias_encoder_type == 'mean':
            # NOTE(review): this branch creates no ``self.bias_encoder``, yet
            # ``_calc_att_clas_loss`` and ``cal_decoder_with_predictor`` call
            # it unconditionally - confirm how 'mean' is meant to be used.
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
        else:
            # Fail at construction time: without the bias modules every later
            # training/inference call would die with an AttributeError.
            logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
            raise ValueError("Unsupport bias encoder type: {}".format(bias_encoder_type))

        self.target_buffer_length = target_buffer_length
        if self.target_buffer_length > 0:
            self.hotword_buffer = None
            self.length_record = []
            self.current_buffer_length = 0
        self.use_decoder_embedding = use_decoder_embedding
        self.crit_attn_weight = crit_attn_weight
        if self.crit_attn_weight > 0:
            self.attn_loss = torch.nn.L1Loss()
        self.crit_attn_smooth = crit_attn_smooth

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
        ideal_attn: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            hotword_pad: padded hot-word token ids, fed to the bias embedding
            hotword_lengths: number of valid tokens of each hot word
            ideal_attn: passed through to ``_calc_att_clas_loss``

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``;
            ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        loss_ideal = None

        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
                encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths, ideal_attn
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

        if loss_ideal is not None:
            loss = loss + loss_ideal * self.crit_attn_weight
            stats["loss_ideal"] = loss_ideal.detach().cpu()

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _calc_att_clas_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        hotword_pad: torch.Tensor,
        hotword_lengths: torch.Tensor,
        ideal_attn: torch.Tensor,
    ):
        """Compute the decoder (attention) and predictor losses with hot-word context.

        Returns:
            ``(loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal)``.
            ``cer_att``/``wer_att`` are None while training or when no error
            calculator is set; ``loss_ideal`` is currently always None.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)

        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        # Take the encoder output at each hot word's last valid position ...
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]]
        # ... and give every utterance in the batch the same hot-word list.
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds, contextual_info)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        # NOTE: the ideal-attention loss is disabled, so ``loss_ideal`` is
        # always None and ``ideal_attn`` is unused. Disabled code kept for
        # reference:
        # if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
        #     ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
        #     attn_non_blank = attn[:,:,:,:-1]
        #     ideal_attn_non_blank = ideal_attn[:,:,:-1]
        #     loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
        # else:
        #     loss_ideal = None
        loss_ideal = None

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
        """Mix ground-truth token embeddings into the predictor's acoustic embeddings.

        A gradient-free first decoder pass counts how many target tokens are
        already predicted correctly. For each utterance,
        ``(target length - correct tokens) * sampling_ratio`` randomly chosen
        positions get the ground-truth token embedding; all other positions
        keep the acoustic embedding.

        Returns:
            ``(sematic_embeds, decoder_out)``, both multiplied by the target
            padding mask; ``decoder_out`` is the first-pass decoder output.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            # 1 = keep the acoustic embedding, 0 = use the ground-truth embedding
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
        """Run the decoder for inference, conditioned on an optional hot-word list.

        Args:
            hw_list: Sequence of hot words, each a sequence of token ids.
                When None, a single one-token placeholder (id 1) is used.

        Returns:
            ``(decoder_out, ys_pad_lens)`` where ``decoder_out`` holds
            log-softmax scores over the last dimension.
        """
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            _, (h_n, _) = self.bias_encoder(hw_embed)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            # Pack so the final LSTM state of each hot word ignores its padding.
            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
                                                               enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
        # Final hidden state of every hot word, repeated for each utterance.
        hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)

        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
| | |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | |
| | | paraformer_bert=ParaformerBert, |
| | | bicif_paraformer=BiCifParaformer, |
| | | contextual_paraformer=ContextualParaformer, |
| | | neatcontextual_paraformer=NeatContextualParaformer, |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | ), |
| | |
| | | normalize = None |
| | | |
| | | # 4. Encoder |
| | | |
| | | if getattr(args, "encoder", None) is not None: |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size, **args.encoder_conf) |
| | |
| | | if ignore_init_mismatch: |
| | | src_state = filter_state_dict(dst_state, src_state) |
| | | |
| | | logging.info("Loaded src_state keys: {}".format(src_state.keys())) |
| | | # logging.info("Loaded src_state keys: {}".format(src_state.keys())) |
| | | dst_state.update(src_state) |
| | | obj.load_state_dict(dst_state) |