zhifu gao
2023-05-05 f1ef7cf48d83e18ce315e37b322146677355f4f0
Merge pull request #453 from alibaba-damo-academy/dev_clas

Update NeatContextualParaformer, finetune pipeline and dataset
8个文件已修改
5个文件已添加
1个文件已删除
706 ■■■■■ 已修改文件
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/build_trainer.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/dataset.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/hotword_utils.py 32 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/padding.py 58 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/tokenize.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py 372 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/torch_utils/load_pretrained_model.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
@@ -1 +1 @@
../TEMPLATE/README.md
../../TEMPLATE/README.md
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
New file
@@ -0,0 +1,37 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
import funasr
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model using the options carried by ``params``.

    ``params`` must expose ``output_dir``, ``data_path``, ``model``,
    ``dataset_type``, ``batch_bins``, ``max_epoch`` and ``lr``. The output
    directory is created when missing, the dataset at ``data_path`` is loaded
    with ``MsDataset.load``, and a ``speech_asr_trainer`` is built and run.
    """
    out_dir = params.output_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Dataset splits used: ["train", "validation"].
    dataset_splits = MsDataset.load(params.data_path)

    trainer_args = {
        "model": params.model,
        "data_dir": dataset_splits,
        "dataset_type": params.dataset_type,
        "work_dir": out_dir,
        "batch_bins": params.batch_bins,
        "max_epoch": params.max_epoch,
        "lr": params.lr,
    }
    asr_trainer = build_trainer(Trainers.speech_asr_trainer, default_args=trainer_args)
    asr_trainer.train()
if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", data_path="./data")
    # Script-level overrides, applied in this order onto the parsed params.
    overrides = {
        # Directory where fine-tuned checkpoints are saved.
        "output_dir": "./checkpoint",
        # Location of the training data.
        "data_path": "./example_data/",
        # Use "small" for small corpora; for more than 1000 hours of data use "large".
        "dataset_type": "large",
        # Batch size: counted in fbank frames when dataset_type="small",
        # in milliseconds when dataset_type="large".
        "batch_bins": 2000,
        # Maximum number of training epochs.
        "max_epoch": 50,
        # Learning rate.
        "lr": 0.00005,
    }
    for option_name, option_value in overrides.items():
        setattr(params, option_name, option_value)
    modelscope_finetune(params)
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh
File was deleted
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Decode a test set with the contextual Paraformer model (stage 1), score CER
# (stage 2) and optionally apply SpeechIO text normalisation before scoring
# (stage 3). Any variable below can be overridden on the command line through
# utils/parse_options.sh, e.g. `bash infer.sh --gpu_inference false --njob 8`.
set -e
set -u
set -o pipefail
stage=1
stop_stage=2
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
data_dir="./data/test"
output_dir="./results"
batch_size=64
gpu_inference=true    # whether to perform gpu decoding
gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
njob=10    # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"
hotword_txt=None
. utils/parse_options.sh || exit 1;

# Compare the flag as a string. The previous form `if ${gpu_inference} == "true"`
# executed the flag's value (`true`/`false`) as a command.
if [ "${gpu_inference}" = "true" ]; then
    # one decoding job per listed GPU
    nj=$(echo "${gpuid_list}" | awk -F "," '{print NF}')
else
    # CPU decoding: njob jobs, batch size 1, gpuid -1 for every job
    nj=${njob}
    batch_size=1
    gpuid_list=""
    for JOB in $(seq ${nj}); do
        gpuid_list=${gpuid_list}"-1,"
    done
fi

mkdir -p "${output_dir}/split"
split_scps=""
for JOB in $(seq ${nj}); do
    split_scps="${split_scps} ${output_dir}/split/wav.${JOB}.scp"
done
# ${split_scps} is intentionally unquoted: it must expand to one path per job.
perl utils/split_scp.pl "${data_dir}/wav.scp" ${split_scps}

if [ -n "${checkpoint_dir}" ]; then
  python utils/prepare_checkpoint.py "${model}" "${checkpoint_dir}" "${checkpoint_name}"
  model="${checkpoint_dir}/${model}"
fi

if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then
    echo "Decoding ..."
    # intentionally unquoted: split "0,1" into an array of GPU ids
    gpuid_list_array=(${gpuid_list//,/ })
    pids=()
    for JOB in $(seq ${nj}); do
        {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}
        mkdir -p "${output_dir}/output.${JOB}"
        python infer.py \
            --model "${model}" \
            --audio_in "${output_dir}/split/wav.${JOB}.scp" \
            --output_dir "${output_dir}/output.${JOB}" \
            --batch_size ${batch_size} \
            --hotword_txt "${hotword_txt}" \
            --gpuid ${gpuid}
        }&
        pids+=($!)
    done
    # A bare `wait` always returns 0, so a crashed job went unnoticed and only
    # partial results were merged below. Wait on every job and fail if any did.
    failed=0
    for pid in ${pids[@]:-}; do
        wait "${pid}" || failed=1
    done
    if [ "${failed}" -ne 0 ]; then
        echo "$0: at least one decoding job failed" >&2
        exit 1
    fi
    # merge per-job outputs, sorted by utterance key
    mkdir -p "${output_dir}/1best_recog"
    for f in token score text; do
        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
          for i in $(seq "${nj}"); do
              cat "${output_dir}/output.${i}/1best_recog/${f}"
          done | sort -k1 >"${output_dir}/1best_recog/${f}"
        fi
    done
fi

if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then
    echo "Computing WER ..."
    cp "${output_dir}/1best_recog/text" "${output_dir}/1best_recog/text.proc"
    cp "${data_dir}/text" "${output_dir}/1best_recog/text.ref"
    python utils/compute_wer.py "${output_dir}/1best_recog/text.ref" "${output_dir}/1best_recog/text.proc" "${output_dir}/1best_recog/text.cer"
    tail -n 3 "${output_dir}/1best_recog/text.cer"
fi

if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then
    echo "SpeechIO TIOBE textnorm"
    echo "$0 --> Normalizing REF text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${data_dir}/text" \
        "${output_dir}/1best_recog/ref.txt"
    echo "$0 --> Normalizing HYP text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        "${output_dir}/1best_recog/text.proc" \
        "${output_dir}/1best_recog/rec.txt"
    # drop hypotheses that are empty after normalisation (line ends in a tab)
    grep -v $'\t$' "${output_dir}/1best_recog/rec.txt" > "${output_dir}/1best_recog/rec_non_empty.txt"
    echo "$0 --> computing WER/CER and alignment ..."
    ./utils/error_rate_zh \
        --tokenizer char \
        --ref "${output_dir}/1best_recog/ref.txt" \
        --hyp "${output_dir}/1best_recog/rec_non_empty.txt" \
        "${output_dir}/1best_recog/DETAILS.txt" | tee "${output_dir}/1best_recog/RESULTS.txt"
    rm -rf "${output_dir}/1best_recog/rec.txt" "${output_dir}/1best_recog/rec_non_empty.txt"
fi
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
New file
@@ -0,0 +1,40 @@
import os
import tempfile
import codecs
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.msdatasets import MsDataset
if __name__ == '__main__':
    param_dict = dict()
    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
    output_dir = "./output"
    batch_size = 1
    max_utterances = 50  # only the first 50 utterances of the split are decoded
    # dataset split ['test']
    ds_dict = MsDataset.load(dataset_name='speech_asr_aishell1_hotwords_testsets', namespace='speech_asr')
    # Keep the TemporaryDirectory object alive for the whole run. The previous
    # `tempfile.TemporaryDirectory().name` dropped the object immediately, and its
    # finalizer deleted the directory. The code then re-created it by hand and
    # never removed it.
    with tempfile.TemporaryDirectory() as work_dir:
        wav_file_path = os.path.join(work_dir, "wav.scp")
        # Kaldi-style scp: one "<utt_id> <wav_path>" line per utterance.
        with open(wav_file_path, 'w', encoding='utf-8') as fout:
            for counter, line in enumerate(ds_dict, start=1):
                wav = line["Audio:FILE"]
                # utterance id = file name up to its first '.'
                idx = wav.split("/")[-1].split(".")[0]
                fout.write(idx + " " + wav + "\n")
                if counter == max_utterances:
                    break
        audio_in = wav_file_path
        inference_pipeline = pipeline(
            task=Tasks.auto_speech_recognition,
            model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
            output_dir=output_dir,
            batch_size=batch_size,
            param_dict=param_dict)
        # Decoding must finish before the temporary scp file is removed.
        rec_result = inference_pipeline(audio_in=audio_in)
funasr/bin/asr_inference_paraformer.py
@@ -41,6 +41,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_inference import SpeechText2Timestamp
@@ -236,7 +237,7 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
funasr/bin/build_trainer.py
@@ -83,7 +83,8 @@
        finetune_configs = yaml.safe_load(f)
        # set data_types
        if dataset_type == "large":
            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
            if 'data_types' not in finetune_configs['dataset_conf']:
                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
@@ -131,7 +132,8 @@
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
            if "batch_size" not in args.dataset_conf["batch_conf"]:
                args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    if args.normalize in ["null", "none", "None"]:
funasr/datasets/large_datasets/dataset.py
@@ -101,7 +101,7 @@
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound":
                elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
@@ -131,6 +131,13 @@
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "text_hotword":
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict['hw_tag'] = 1
                    else:
                        text = item
                        segs = text.strip().split()
@@ -167,14 +174,38 @@
    shuffle = conf.get('shuffle', True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")
    dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
    pre_hwfile = conf.get("pre_hwlist", None)
    pre_prob = conf.get("pre_prob", 0)  # unused yet
    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
                 "double_rate": conf.get("double_rate", 0.1),
                 "hotword_min_length": conf.get("hotword_min_length", 2),
                 "hotword_max_length": conf.get("hotword_max_length", 8),
                 "pre_prob": conf.get("pre_prob", 0.0)}
    if pre_hwfile is not None:
        pre_hwlist = []
        with open(pre_hwfile, 'r') as fin:
            for line in fin.readlines():
                pre_hwlist.append(line.strip())
    else:
        pre_hwlist = None
    dataset = AudioDataset(scp_lists,
                           data_names,
                           data_types,
                           frontend_conf=frontend_conf,
                           shuffle=shuffle,
                           mode=mode,
                           )
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
funasr/datasets/large_datasets/utils/hotword_utils.py
New file
@@ -0,0 +1,32 @@
import random
def sample_hotword(length,
                   hotword_min_length,
                   hotword_max_length,
                   sample_rate,
                   double_rate,
                   pre_prob,
                   pre_index=None):
    """Randomly pick token spans of a sentence to act as training hotwords.

    Args:
        length: number of tokens in the sentence.
        hotword_min_length: minimum hotword length, in tokens.
        hotword_max_length: maximum hotword length, in tokens.
        sample_rate: probability that the sentence gets any hotword at all.
        double_rate: probability of sampling two hotwords instead of one; only
            tried when ``length > hotword_max_length + hotword_min_length + 2``.
        pre_prob: probability of returning ``pre_index`` unchanged (if given).
        pre_index: pre-computed span list returned with probability ``pre_prob``.

    Returns:
        ``[-1]`` when no hotword is sampled, ``[start, end]`` for one hotword,
        or ``[start1, end1, start2, end2]`` for two non-overlapping hotwords
        with ``end1 < start2``. Bounds are inclusive token indices.
    """
    if length < hotword_min_length:
        return [-1]
    if not random.random() < sample_rate:
        return [-1]
    if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
        return pre_index
    if length == hotword_min_length:
        return [0, length - 1]
    if random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
        # Sample two hotwords in a sentence; the first starts in its first third.
        max_hw_length = min(hotword_max_length, length // 2)
        # Clamp the first span so that at least `hotword_min_length` tokens
        # remain after it for the second one. Without the clamp, configurations
        # such as min=4 / max=5 / length=12 made
        # `randint(end1 + 1, length - hotword_min_length)` raise ValueError.
        # For the default config (min 2, max 8) the clamps never bind.
        start1 = random.randint(0, min(length // 3, length - 2 * hotword_min_length))
        end1 = random.randint(start1 + hotword_min_length - 1,
                              min(start1 + max_hw_length - 1, length - hotword_min_length - 1))
        # second hotword
        start2 = random.randint(end1 + 1, length - hotword_min_length)
        end2 = random.randint(min(length - 1, start2 + hotword_min_length - 1),
                              min(length - 1, start2 + hotword_max_length - 1))
        return [start1, end1, start2, end2]
    # single hotword
    start = random.randint(0, length - hotword_min_length)
    end = random.randint(min(length - 1, start + hotword_min_length - 1),
                         min(length - 1, start + hotword_max_length - 1))
    return [start, end]
funasr/datasets/large_datasets/utils/padding.py
@@ -13,15 +13,16 @@
    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key" or data_name =="sampling_rate":
        if data_name == "key" or data_name == "sampling_rate":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                pad_value = int_pad_value
                tensor_type = torch.int64
            else:
                pad_value = float_pad_value
                tensor_type = torch.float32
            if data_name != 'hotword_indxs':
                if data[0][data_name].dtype.kind == "i":
                    pad_value = int_pad_value
                    tensor_type = torch.int64
                else:
                    pad_value = float_pad_value
                    tensor_type = torch.float32
            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
@@ -31,4 +32,47 @@
            batch[data_name] = tensor_pad
            batch[data_name + "_lengths"] = tensor_lengths
    # DHA, EAHC NOT INCLUDED
    if "hotword_indxs" in batch:
        # if hotword indxs in batch
        # use it to slice hotwords out
        hotword_list = []
        hotword_lengths = []
        text = batch['text']
        text_lengths = batch['text_lengths']
        hotword_indxs = batch['hotword_indxs']
        num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
        B, t1 = text.shape
        t1 += 1  # TODO: as parameter which is same as predictor_bias
        ideal_attn = torch.zeros(B, t1, num_hw+1)
        nth_hw = 0
        for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
            ideal_attn[b][:,-1] = 1
            if hotword_indx[0] != -1:
                start, end = int(hotword_indx[0]), int(hotword_indx[1])
                hotword = one_text[start: end+1]
                hotword_list.append(hotword)
                hotword_lengths.append(end-start+1)
                ideal_attn[b][start:end+1, nth_hw] = 1
                ideal_attn[b][start:end+1, -1] = 0
                nth_hw += 1
                if len(hotword_indx) == 4 and hotword_indx[2] != -1:
                    # the second hotword if exist
                    start, end = int(hotword_indx[2]), int(hotword_indx[3])
                    hotword_list.append(one_text[start: end+1])
                    hotword_lengths.append(end-start+1)
                    ideal_attn[b][start:end+1, nth_hw-1] = 1
                    ideal_attn[b][start:end+1, -1] = 0
                    nth_hw += 1
        hotword_list.append(torch.tensor([1]))
        hotword_lengths.append(1)
        hotword_pad = pad_sequence(hotword_list,
                                batch_first=True,
                                padding_value=0)
        batch["hotword_pad"] = hotword_pad
        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
        batch['ideal_attn'] = ideal_attn
        del batch['hotword_indxs']
        del batch['hotword_indxs_lengths']
    return keys, batch
funasr/datasets/large_datasets/utils/tokenize.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import re
import numpy as np
from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
def forward_segment(text, seg_dict):
    word_list = []
@@ -38,7 +39,8 @@
             vocab=None,
             seg_dict=None,
             punc_dict=None,
             bpe_tokenizer=None):
             bpe_tokenizer=None,
             hw_config=None):
    assert "text" in data
    assert isinstance(vocab, dict)
    text = data["text"]
@@ -53,6 +55,10 @@
        text = seg_tokenize(text, seg_dict)
    length = len(text)
    if 'hw_tag' in data:
        hotword_indxs = sample_hotword(length, **hw_config)
        data['hotword_indxs'] = hotword_indxs
        del data['hw_tag']
    for i in range(length):
        x = text[i]
        if i == length-1 and "punc" in data and x.startswith("vad:"):
funasr/models/e2e_asr_contextual_paraformer.py
New file
@@ -0,0 +1,372 @@
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.e2e_asr_paraformer import Paraformer
# Detect AMP support by trying the import instead of parsing torch.__version__
# with distutils' LooseVersion (distutils is deprecated, PEP 632).
# torch.cuda.amp.autocast is available from torch 1.6.0 onwards.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: fall back to a no-op context manager.
    @contextmanager
    def autocast(enabled=True):
        yield
class NeatContextualParaformer(Paraformer):
    """Paraformer with a contextual (hotword) biasing branch.

    Hotword token sequences are embedded, either with ``bias_embed`` or, when
    ``use_decoder_embedding`` is true, with the decoder's own embedding. They
    are encoded by the LSTM ``bias_encoder`` and reduced to one vector per
    hotword. The result reaches the decoder through its ``contextual_info``
    argument.

    Extra arguments on top of ``Paraformer``:
        target_buffer_length: stored. When > 0 the hotword-buffer attributes
            are initialised, but nothing in this class reads them.
        inner_dim: size of the hotword embedding and of the LSTM state.
        bias_encoder_type: ``'lstm'``, or ``'mean'``, which only builds the
            embedding table. Any other value raises ``ValueError``.
        use_decoder_embedding: embed hotwords with ``decoder.embed`` instead of
            the dedicated ``bias_embed`` table.
        crit_attn_weight / crit_attn_smooth: stored for an attention-guidance
            loss that is currently disabled (the loss is always ``None``).
        bias_encoder_dropout_rate: ``dropout`` passed to the LSTM.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor = None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        target_buffer_length: int = -1,
        inner_dim: int = 256,
        bias_encoder_type: str = 'lstm',
        use_decoder_embedding: bool = False,
        crit_attn_weight: float = 0.0,
        crit_attn_smooth: float = 0.0,
        bias_encoder_dropout_rate: float = 0.0,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        if bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
        elif bias_encoder_type == 'mean':
            # NOTE(review): no ``bias_encoder`` is created here, yet
            # ``_calc_att_clas_loss`` and ``cal_decoder_with_predictor`` call
            # ``self.bias_encoder`` unconditionally. This mode looks
            # incomplete; confirm before using it.
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
        else:
            # Fail fast: without a bias embedding, every forward pass would
            # later die with an AttributeError.
            logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
            raise ValueError("Unsupport bias encoder type: {}".format(bias_encoder_type))
        self.target_buffer_length = target_buffer_length
        if self.target_buffer_length > 0:
            self.hotword_buffer = None
            self.length_record = []
            self.current_buffer_length = 0
        self.use_decoder_embedding = use_decoder_embedding
        self.crit_attn_weight = crit_attn_weight
        if self.crit_attn_weight > 0:
            self.attn_loss = torch.nn.L1Loss()
        self.crit_attn_smooth = crit_attn_smooth

    def _hotword_embedding(self, hotword_tokens):
        """Embed padded hotword token ids with the decoder or the bias table."""
        if self.use_decoder_embedding:
            return self.decoder.embed(hotword_tokens)
        return self.bias_embed(hotword_tokens)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            hotword_pad: torch.Tensor,
            hotword_lengths: torch.Tensor,
            ideal_attn: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder (with hotword biasing) + Calc loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            hotword_pad: padded hotword token ids, one row per hotword. The
                list is shared by the whole batch; its row count is independent
                of Batch.
            hotword_lengths: number of real tokens in each ``hotword_pad`` row.
            ideal_attn: target attention map for the disabled attention loss.
                Currently unused.

        Returns:
            ``(loss, stats, weight)`` after ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        loss_ideal = None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch (with hotword biasing)
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
                encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths, ideal_attn
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        if loss_ideal is not None:
            loss = loss + loss_ideal * self.crit_attn_weight
            stats["loss_ideal"] = loss_ideal.detach().cpu()

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _calc_att_clas_loss(
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
            hotword_pad: torch.Tensor,
            hotword_lengths: torch.Tensor,
            ideal_attn: torch.Tensor,
    ):
        """Compute the hotword-biased decoder and predictor losses.

        Returns:
            ``(loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal)``.
            ``cer_att``/``wer_att`` are ``None`` during training or without an
            error calculator. ``loss_ideal`` is always ``None`` for now.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)

        # -1. bias encoder: summarise each hotword by the LSTM output at its last
        # real token. The LSTM is unidirectional, so outputs at later padded
        # steps are never read and no packing is needed.
        hw_embed = self._hotword_embedding(hotword_pad)
        hw_embed, _ = self.bias_encoder(hw_embed)
        hw_index = torch.arange(hw_embed.size(0), device=hw_embed.device)
        last_token = (hotword_lengths.long() - 1).to(hw_embed.device)
        selected = hw_embed[hw_index, last_token]
        # Every utterance in the batch attends to the same batch-level hotword list.
        contextual_info = selected.unsqueeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds, contextual_info)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out = decoder_outs[0]

        # TODO: the attention-guidance loss against ``ideal_attn`` is disabled,
        # so ``crit_attn_weight`` currently has no effect. The disabled draft:
        #   if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
        #       ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
        #       attn_non_blank = attn[:,:,:,:-1]
        #       ideal_attn_non_blank = ideal_attn[:,:,:-1]
        #       loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
        #   else:
        #       loss_ideal = None
        loss_ideal = None

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
        """Glancing sampler: replace some acoustic embeddings by target-token embeddings.

        The number of replaced positions per utterance is
        ``(#tokens the first decoder pass got wrong) * sampling_ratio``.

        Returns:
            ``(sematic_embeds, decoder_out)``, both zeroed beyond ``ys_pad_lens``.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Zero the padding only in a copy used for the embedding lookup. ``ys_pad``
        # itself must keep ``ignore_id`` so that padding can be told apart below;
        # overwriting it made ``nonpad_positions`` all-True, so padding counted
        # toward ``seq_lens`` and could be picked for replacement.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
            )
            decoder_out = decoder_outs[0]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = nonpad_positions.sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            for li in range(ys_pad.size(0)):
                target_num = ((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
        """Run the hotword-biased decoder at inference time.

        Args:
            hw_list: optional list of hotword token-id sequences. When ``None``,
                a single placeholder hotword ``[1]`` stands in for "no hotwords".

        Returns:
            ``(log_softmax(decoder_out), ys_pad_lens)``.
        """
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            hw_embed = self._hotword_embedding(hw_list_pad)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            # Same layout as the explicit-hotword branch and as training: one
            # summary vector per hotword, repeated for every utterance in the
            # batch. Previously the raw output was passed with batch size 1.
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
            hw_embed = self._hotword_embedding(hw_list_pad)
            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
                                                               enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
funasr/tasks/asr.py
@@ -42,6 +42,7 @@
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
@@ -128,6 +129,7 @@
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
@@ -1647,7 +1649,6 @@
            normalize = None
        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
funasr/torch_utils/load_pretrained_model.py
@@ -120,6 +120,6 @@
    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)
    logging.info("Loaded src_state keys: {}".format(src_state.keys()))
    # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)