Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
| New file |
| | |
| | | # ModelScope Model |
| | | |
| | | ## How to finetune and infer using a pretrained Paraformer-large Model |
| | | |
| | | ### Finetune |
| | | |
| | | - Modify finetune training related parameters in `finetune.py` |
| | | - <strong>output_dir:</strong> # result dir |
| | | - <strong>data_path:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text` |
| | | - <strong>dataset_type:</strong> # for datasets larger than 1000 hours, set to `large`; otherwise set to `small` |
| | | - <strong>batch_bins:</strong> # batch size. When dataset_type is `small`, `batch_bins` is the number of feature frames; when dataset_type is `large`, `batch_bins` is the duration in ms |
| | | - <strong>max_epoch:</strong> # number of training epochs |
| | | - <strong>lr:</strong> # learning rate |
| | | |
| | | - Then you can run the pipeline to finetune with: |
| | | ```shell |
| | | python finetune.py |
| | | ``` |
| | | |
| | | ### Inference |
| | | |
| | | Or you can use the pretrained model for inference directly. |
| | | |
| | | - Setting parameters in `infer.py` |
| | | - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed |
| | | - <strong>output_dir:</strong> # result dir |
| | | - <strong>ngpu:</strong> # the number of GPUs for decoding |
| | | - <strong>njob:</strong> # the number of jobs for each GPU |
| | | |
| | | - Then you can run the pipeline to infer with: |
| | | ```shell |
| | | python infer.py |
| | | ``` |
| | | |
| | | - Results |
| | | |
| | | The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set. |
| | | |
| | | ### Inference using local finetuned model |
| | | |
| | | - Modify inference related parameters in `infer_after_finetune.py` |
| | | - <strong>output_dir:</strong> # result dir |
| | | - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed |
| | | - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth` |
| | | |
| | | - Then you can run the pipeline to infer with: |
| | | ```shell |
| | | python infer_after_finetune.py |
| | | ``` |
| | | |
| | | - Results |
| | | |
| | | The decoding results can be found in `$output_dir/decode_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set. |
| New file |
| | |
| | | import os |
| | | |
| | | from modelscope.metainfo import Trainers |
| | | from modelscope.trainers import build_trainer |
| | | |
| | | from funasr.datasets.ms_dataset import MsDataset |
| | | from funasr.utils.modelscope_param import modelscope_args |
| | | |
| | | |
def modelscope_finetune(params):
    """Finetune ``params.model`` on the dataset found at ``params.data_path``.

    Attributes read from ``params``: ``output_dir`` (created if missing and
    used as the trainer work dir), ``data_path``, ``model``, ``dataset_type``,
    ``batch_bins``, ``max_epoch`` and ``lr``. Training runs synchronously via
    the ModelScope ``speech_asr_trainer``.
    """
    work_dir = params.output_dir
    if not os.path.exists(work_dir):
        os.makedirs(work_dir, exist_ok=True)

    # The loaded dataset is expected to expose the "train" and "validation" splits.
    dataset = MsDataset.load(params.data_path)

    trainer_args = {
        "model": params.model,
        "data_dir": dataset,
        "dataset_type": params.dataset_type,
        "work_dir": work_dir,
        "batch_bins": params.batch_bins,
        "max_epoch": params.max_epoch,
        "lr": params.lr,
    }
    asr_trainer = build_trainer(Trainers.speech_asr_trainer, default_args=trainer_args)
    asr_trainer.train()
| | | |
| | | |
if __name__ == '__main__':
    # Start from ModelScope's default finetuning arguments for this model.
    # NOTE(review): data_path passed here is overridden just below.
    params = modelscope_args(
        model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
        data_path="./data",
    )
    # Recipe-specific overrides, applied in this order.
    overrides = {
        "output_dir": "./checkpoint",
        "data_path": "./example_data/",
        "dataset_type": "small",
        "batch_bins": 16000,
        "max_epoch": 50,
        "lr": 0.00002,
    }
    for field_name, field_value in overrides.items():
        setattr(params, field_name, field_value)

    modelscope_finetune(params)
| New file |
| | |
| | | import os |
| | | import shutil |
| | | from multiprocessing import Pool |
| | | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | | from funasr.utils.compute_wer import compute_wer |
| | | |
| | | |
def modelscope_infer_core(output_dir, split_dir, njob, idx,
                          model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k"):
    """Decode one shard ``<split_dir>/wav.<idx>.scp`` on the GPU assigned to job ``idx``.

    Args:
        output_dir: root output dir; this job writes to ``<output_dir>/output.<idx>``.
        split_dir: directory holding the per-job ``wav.<idx>.scp`` shards.
        njob: number of jobs sharing each GPU.
        idx: 1-based job index (passed as a string by the caller).
        model: ModelScope model id or local model dir used for decoding.

    Jobs ``1..njob`` use GPU slot 0, jobs ``njob+1..2*njob`` slot 1, and so on.
    When ``CUDA_VISIBLE_DEVICES`` is already set, the slot indexes into that
    list; otherwise the slot number itself is used as the device id.

    Raises:
        IndexError: if the GPU slot exceeds the devices listed in
            ``CUDA_VISIBLE_DEVICES``.
    """
    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
    gpu_id = (int(idx) - 1) // njob

    # Restrict this process to a single device before the pipeline is built
    # (presumably CUDA reads this variable when it initializes).
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        gpu_list = [gpu.strip() for gpu in visible.split(",")]
        if gpu_id >= len(gpu_list):
            raise IndexError(
                "job {} maps to GPU slot {}, but CUDA_VISIBLE_DEVICES={!r} lists only {} device(s)".format(
                    idx, gpu_id, visible, len(gpu_list)))
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list[gpu_id]
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=model,
        output_dir=output_dir_job,
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipeline(audio_in=audio_in)
| | | |
| | | |
def _split_scp(lines, nj, split_dir):
    """Write ``lines`` into ``nj`` contiguous shards ``wav.1.scp`` .. ``wav.<nj>.scp``.

    Shard sizes differ by at most one line. Concatenating the shards in index
    order reproduces ``lines`` exactly, so merged outputs keep the original
    utterance order.
    """
    base, extra = divmod(len(lines), nj)
    start = 0
    for i in range(nj):
        # The first ``extra`` shards each take one additional line.
        end = start + base + (1 if i < extra else 0)
        shard_path = os.path.join(split_dir, "wav.{}.scp".format(i + 1))
        with open(shard_path, "w") as f:
            f.writelines(lines[start:end])
        start = end


def _run_decoding_jobs(output_dir, split_dir, njob, nj):
    """Run ``modelscope_infer_core`` for shards 1..nj in parallel and wait for all.

    Re-raises the exception of the lowest-numbered job that failed.
    """
    # maxtasksperchild=1 gives every job a fresh worker process, so the
    # CUDA_VISIBLE_DEVICES value one job sets can never be seen by another.
    pool = Pool(nj, maxtasksperchild=1)
    try:
        jobs = [
            pool.apply_async(modelscope_infer_core,
                             args=(output_dir, split_dir, njob, str(i + 1)))
            for i in range(nj)
        ]
        pool.close()
        # AsyncResult.get() re-raises a worker's exception here. Without it, a
        # failed job goes unnoticed until the merge step trips over a missing file.
        for job in jobs:
            job.get()
    except BaseException:
        pool.terminate()
        raise
    finally:
        pool.join()


def _merge_job_outputs(output_dir, nj):
    """Concatenate every job's ``1best_recog/{text,token,score}`` in job order.

    Returns the path of the merged ``<output_dir>/1best_recog`` directory.
    """
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    for name in ("text", "token", "score"):
        with open(os.path.join(best_recog_path, name), "w") as f_out:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}".format(i + 1), "1best_recog", name)
                with open(job_file) as f_job:
                    f_out.writelines(f_job)
    return best_recog_path


def modelscope_infer(params):
    """Decode ``<data_dir>/wav.scp`` with parallel jobs spread over several GPUs.

    ``params`` keys read:
        data_dir: directory containing ``wav.scp`` and, optionally, ``text``.
        output_dir: result directory; it is deleted and recreated.
        ngpu: number of GPUs to use.
        njob: number of jobs per GPU.

    Per-job results are merged into ``<output_dir>/1best_recog``. When
    ``<data_dir>/text`` exists, the CER is written to
    ``<output_dir>/1best_recog/text.cer``.

    Raises:
        ValueError: if ``ngpu * njob`` < 1 or ``wav.scp`` is empty.
    """
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    nj = ngpu * njob
    if nj < 1:
        raise ValueError(
            "ngpu * njob must be at least 1, got ngpu={}, njob={}".format(ngpu, njob))

    # Validate the input before wiping any previous results.
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
    if not lines:
        raise ValueError("{} is empty; nothing to decode".format(wav_scp_file))
    # Never start more jobs than utterances: an empty shard leaves its job with
    # nothing to decode (and presumably no 1best_recog files for the merge step).
    nj = min(nj, len(lines))

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)

    _split_scp(lines, nj, split_dir)
    _run_decoding_jobs(output_dir, split_dir, njob, nj)
    best_recog_path = _merge_job_outputs(output_dir, nj)

    # If reference text exists, compute CER against the merged tokens.
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
| | | |
| | | |
if __name__ == "__main__":
    # Decode ./data/test with 2 GPUs x 5 jobs per GPU; results go to ./results.
    params = {
        "data_dir": "./data/test",
        "output_dir": "./results",
        "ngpu": 2,
        "njob": 5,
    }
    modelscope_infer(params)
| New file |
| | |
| | | import json |
| | | import os |
| | | import shutil |
| | | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | | from funasr.utils.compute_wer import compute_wer |
| | | |
| | | |
def modelscope_infer_after_finetune(params):
    """Decode ``<data_dir>/wav.scp`` with a finetuned checkpoint stored in ``output_dir``.

    ``params`` keys read:
        modelscope_model_name: hub id of the pretrained model, e.g. ``damo/...``.
        required_files: files copied from the cached pretrained model into
            ``output_dir``. ``configuration.json`` is rewritten rather than copied,
            so that ``model.am_model_name`` becomes ``decoding_model_name``.
        output_dir: directory holding the finetuned checkpoint; it is then
            loaded as the model directory for decoding.
        data_dir: directory containing ``wav.scp`` and, optionally, ``text``.
        decoding_model_name: checkpoint file name to decode with.

    Results are written to ``<output_dir>/decode_results``. When
    ``<data_dir>/text`` exists, the CER goes to ``<output_dir>/decode_results/text.cer``.
    """
    # The pretrained model is expected in the default ModelScope hub cache
    # (presumably downloaded there by an earlier finetune run). expanduser
    # also works where $HOME is unset, unlike os.environ["HOME"].
    pretrained_model_path = os.path.join(
        os.path.expanduser("~"), ".cache", "modelscope", "hub", params["modelscope_model_name"])
    output_dir = params["output_dir"]
    # Make sure the destination exists before files are copied/written into it.
    os.makedirs(output_dir, exist_ok=True)

    for file_name in params["required_files"]:
        src_path = os.path.join(pretrained_model_path, file_name)
        dst_path = os.path.join(output_dir, file_name)
        if file_name == "configuration.json":
            with open(src_path) as f:
                config_dict = json.load(f)
            # Point the model config at the finetuned checkpoint.
            config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(dst_path, "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(src_path, dst_path)

    decoding_path = os.path.join(output_dir, "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)

    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=output_dir,
        output_dir=decoding_path,
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)

    # compute CER if GT text is present
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
| | | |
| | | |
if __name__ == '__main__':
    # Decode ./data/test with the checkpoint that finetune.py left in ./checkpoint.
    params = {
        "modelscope_model_name": "damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
        "required_files": ["am.mvn", "decoding.yaml", "configuration.json"],
        "output_dir": "./checkpoint",
        "data_dir": "./data/test",
        "decoding_model_name": "valid.cer_ctc.ave.pth",
    }
    modelscope_infer_after_finetune(params)
| New file |
| | |
| | | # Paraformer-Large |
| | | - Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/summary> |
| | | - Model size: 220M |
| | | |
| | | # Environments |
| | | - date: `Fri Feb 10 13:34:24 CST 2023` |
| | | - python version: `3.7.12` |
| | | - FunASR version: `0.1.6` |
| | | - pytorch version: `pytorch 1.7.0` |
| | | - Git hash: `` |
| | | - Commit date: `` |
| | | |
| | | # Benchmark Results |
| | | |
| | | ## AISHELL-1 |
| | | - Decode config: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | |
| | | | testset CER(%) | base model|finetune model | |
| | | |:--------------:|:---------:|:-------------:| |
| | | | dev | 1.75 |1.62 | |
| | | | test | 1.95 |1.78 | |
| New file |
| | |
| | | # Paraformer-Large |
| | | - Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/summary> |
| | | - Model size: 220M |
| | | |
| | | # Environments |
| | | - date: `Fri Feb 10 13:34:24 CST 2023` |
| | | - python version: `3.7.12` |
| | | - FunASR version: `0.1.6` |
| | | - pytorch version: `pytorch 1.7.0` |
| | | - Git hash: `` |
| | | - Commit date: `` |
| | | |
| | | # Benchmark Results |
| | | |
| | | ## AISHELL-2 |
| | | - Decode config: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | |
| | | | testset CER(%) | base model|finetune model| |
| | | |:------------:|:---------:|:------------:| |
| | | | dev_ios | 2.80 |2.60 | |
| | | | test_android | 3.13 |2.84 | |
| | | | test_ios | 2.85 |2.82 | |
| | | | test_mic | 3.06 |2.88 | |
| New file |
| | |
| | | # Paraformer-Large |
| | | - Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary> |
| | | - Model size: 220M |
| | | |
| | | # Environments |
| | | - date: `Tue Nov 22 18:48:39 CST 2022` |
| | | - python version: `3.7.12` |
| | | - FunASR version: `0.1.0` |
| | | - pytorch version: `pytorch 1.7.0` |
| | | - Git hash: `` |
| | | - Commit date: `` |
| | | |
| | | # Benchmark Results |
| | | |
| | | ## AISHELL-1 |
| | | - Decode config: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | |
| | | | testset | CER(%)| |
| | | |:---------:|:-----:| |
| | | | dev | 1.75 | |
| | | | test | 1.95 | |
| | | |
| | | ## AISHELL-2 |
| | | - Decode config: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | |
| | | | testset | CER(%)| |
| | | |:------------:|:-----:| |
| | | | dev_ios | 2.80 | |
| | | | test_android | 3.13 | |
| | | | test_ios | 2.85 | |
| | | | test_mic | 3.06 | |
| | | |
| | | ## Wenetspeech |
| | | - Decode config: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | |
| | | | testset | CER(%)| |
| | | |:---------:|:-----:| |
| | | | dev | 3.57 | |
| | | | test | 6.97 | |
| | | | test_net | 6.74 | |
| | | |
| | | ## SpeechIO TIOBE |
| | | - Decode config 1: |
| | | - Decode without CTC |
| | | - Decode without LM |
| | | - With text norm |
| | | - Decode config 2: |
| | | - Decode without CTC |
| | | - Decode with Transformer-LM |
| | | - LM weight: 0.15 |
| | | - With text norm |
| | | |
| | | | testset | w/o LM | w/ LM | |
| | | |:------------------:|:----:|:----:| |
| | | |SPEECHIO_ASR_ZH00001| 0.49 | 0.35 | |
| | | |SPEECHIO_ASR_ZH00002| 3.23 | 2.86 | |
| | | |SPEECHIO_ASR_ZH00003| 1.13 | 0.80 | |
| | | |SPEECHIO_ASR_ZH00004| 1.33 | 1.10 | |
| | | |SPEECHIO_ASR_ZH00005| 1.41 | 1.18 | |
| | | |SPEECHIO_ASR_ZH00006| 5.25 | 4.85 | |
| | | |SPEECHIO_ASR_ZH00007| 5.51 | 4.97 | |
| | | |SPEECHIO_ASR_ZH00008| 3.69 | 3.18 | |
| | | |SPEECHIO_ASR_ZH00009| 3.02 | 2.78 | |
| | | |SPEECHIO_ASR_ZH00010| 3.35 | 2.99 | |
| | | |SPEECHIO_ASR_ZH00011| 1.54 | 1.25 | |
| | | |SPEECHIO_ASR_ZH00012| 2.06 | 1.68 | |
| | | |SPEECHIO_ASR_ZH00013| 2.57 | 2.25 | |
| | | |SPEECHIO_ASR_ZH00014| 3.86 | 3.08 | |
| | | |SPEECHIO_ASR_ZH00015| 3.34 | 2.67 | |
| New file |
| | |
| | | |
| | | |
# Demo: score a space-separated text string with the Transformer language model.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

text_input = "hello 大 家 好 呀"

lm_pipeline = pipeline(
    task=Tasks.language_model,
    model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
    output_dir="./tmp/",
)

lm_result = lm_pipeline(text_in=text_input)
print(lm_result)
| | | |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | |
| | | import logging |
| | | import sys |
| | | import time |
| | | import copy |
| | | import os |
| | | import codecs |
| | | from pathlib import Path |
| | | from typing import Optional |
| | | from typing import Sequence |
| | |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | |
| | | penalty: float = 0.0, |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | |
| | | self.asr_train_args = asr_train_args |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | |
| | | # 6. [Optional] Build hotword list from file or str |
| | | if hotword_list_or_file is None: |
| | | self.hotword_list = None |
| | | elif os.path.exists(hotword_list_or_file): |
| | | self.hotword_list = [] |
| | | hotword_str_list = [] |
| | | with codecs.open(hotword_list_or_file, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | hw = line.strip() |
| | | hotword_str_list.append(hw) |
| | | self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | self.hotword_list.append([1]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Initialized hotword list from file: {}, hotword list: {}." |
| | | .format(hotword_list_or_file, hotword_str_list)) |
| | | else: |
| | | logging.info("Attempting to parse hotwords as str...") |
| | | self.hotword_list = [] |
| | | hotword_str_list = [] |
| | | for hw in hotword_list_or_file.strip().split(): |
| | | hotword_str_list.append(hw) |
| | | self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | self.hotword_list.append([1]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Hotword list: {}.".format(hotword_str_list)) |
| | | |
| | | |
| | | is_use_lm = lm_weight != 0.0 and lm_file is not None |
| | | if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: |
| | | beam_search = None |
| | |
| | | pre_token_length = pre_token_length.round().long() |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | if not isinstance(self.asr_model, ContextualParaformer): |
| | | if self.hotword_list: |
| | | logging.warning("Hotword is given but asr model is not a ContextualParaformer.") |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | else: |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | results = [] |
| | | b, n, d = decoder_out.size() |
| | |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | | else: |
| | | hotword_list_or_file = None |
| | | |
| | | if ngpu >= 1 and torch.cuda.is_available(): |
| | | device = "cuda" |
| | | else: |
| | |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | hotword_list_or_file=hotword_list_or_file, |
| | | ) |
| | | speech2text = Speech2Text(**speech2text_kwargs) |
| | | |
| | |
| | | ibest_writer["rtf"][key] = rtf_cur |
| | | |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | |
| | | parser.add_argument( |
| | | "--hotword", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="hotword file path or hotwords seperated by space" |
| | | ) |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | |
| | | print(get_commandline_args(), file=sys.stderr) |
| | | parser = get_parser() |
| | | args = parser.parse_args(cmd) |
| | | param_dict = {'hotword': args.hotword} |
| | | kwargs = vars(args) |
| | | kwargs.pop("config", None) |
| | | kwargs['param_dict'] = param_dict |
| | | inference(**kwargs) |
| | | |
| | | |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | if param_dict is not None: |
| | | use_timestamp = param_dict.get('use_timestamp', True) |
| | | else: |
| | | use_timestamp = True |
| | | |
| | | finish_count = 0 |
| | | file_count = 1 |
| | |
| | | text, token, token_int = result[0], result[1], result[2] |
| | | time_stamp = None if len(result) < 4 else result[3] |
| | | |
| | | |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) |
| | | if use_timestamp and time_stamp is not None: |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) |
| | | else: |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed = "" |
| | | time_stamp_postprocessed = "" |
| | | text_postprocessed_punc = postprocessed_result |
| | |
| | | text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ |
| | | postprocessed_result[1], \ |
| | | postprocessed_result[2] |
| | | text_postprocessed_punc = text_postprocessed |
| | | if len(word_lists) > 0 and text2punc is not None: |
| | | text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) |
| | | else: |
| | | text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] |
| | | text_postprocessed_punc = text_postprocessed |
| | | if len(word_lists) > 0 and text2punc is not None: |
| | | text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) |
| | | |
| | | |
| | | item = {'key': key, 'value': text_postprocessed_punc} |
| | |
| | | from typing import Any |
| | | from typing import List |
| | | import math |
| | | import copy |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.tasks.vad import VADTask |
| | | from funasr.utils.timestamp_tools import time_stamp_lfr6 |
| | | from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl |
| | | from funasr.bin.punctuation_infer import Text2Punc |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len, |
| | | pre_token_length) # test no bias cif2 |
| | | |
| | | results = [] |
| | | b, n, d = decoder_out.size() |
| | | for i in range(b): |
| | |
| | | else: |
| | | text = None |
| | | |
| | | time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time) |
| | | |
| | | results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor)) |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time) |
| | | results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) |
| | | else: |
| | | time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time) |
| | | results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor)) |
| | | |
| | | # assert check_return_type(results) |
| | | return results |
| | |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | if param_dict is not None: |
| | | use_timestamp = param_dict.get('use_timestamp', True) |
| | | else: |
| | | use_timestamp = True |
| | | |
| | | finish_count = 0 |
| | | file_count = 1 |
| | |
| | | result = result_segments[0] |
| | | text, token, token_int = result[0], result[1], result[2] |
| | | time_stamp = None if len(result) < 4 else result[3] |
| | | |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) |
| | | |
| | | if use_timestamp and time_stamp is not None: |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp) |
| | | else: |
| | | postprocessed_result = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed = "" |
| | | time_stamp_postprocessed = "" |
| | | text_postprocessed_punc = postprocessed_result |
| | |
| | | text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \ |
| | | postprocessed_result[1], \ |
| | | postprocessed_result[2] |
| | | text_postprocessed_punc = text_postprocessed |
| | | if len(word_lists) > 0 and text2punc is not None: |
| | | text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) |
| | | else: |
| | | text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1] |
| | | |
| | | text_postprocessed_punc = text_postprocessed |
| | | if len(word_lists) > 0 and text2punc is not None: |
| | | text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) |
| | | |
| | | item = {'key': key, 'value': text_postprocessed_punc} |
| | | if text_postprocessed != "": |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build LM |
| | | model, train_args = LMTask.build_model_from_file(train_config, model_file, device) |
| | | model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device) |
| | | # Wrape model to make model.nll() data-parallel |
| | | wrapped_model = ForwardAdaptor(model, "nll") |
| | | wrapped_model.to(dtype=getattr(torch, dtype)).eval() |
| | |
| | | utt_ppl = log_base ** (_nll / ntoken / np.log(log_base)) |
| | | |
| | | # Write PPL of each utts for debugging or analysis |
| | | writer["utt2nll"][key] = str(-_nll) |
| | | writer["utt2ppl"][key] = str(utt_ppl) |
| | | writer["utt2ntokens"][key] = str(ntoken) |
| | | |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | import os |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | from typing import Any |
| | | from typing import List |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from torch.nn.parallel import data_parallel |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.datasets.preprocessor import LMPreprocessor |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.forward_adaptor import ForwardAdaptor |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import float_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    train_config: Optional[str],
    model_file: Optional[str],
    log_base: Optional[float],
    key_file: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    split_with_space: Optional[bool] = False,
    seg_dict_file: Optional[str] = None,
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
    raw_inputs: Union[List[Any], bytes, str] = None,
    **kwargs,
):
    """Build an LM scoring pipeline and run it once.

    Every argument except ``data_path_and_name_and_type`` is forwarded to
    ``inference_modelscope``, together with any extra keyword arguments. The
    callable it returns is invoked with ``data_path_and_name_and_type`` and
    ``raw_inputs``, and its result is returned unchanged.
    """
    builder_args = {
        "output_dir": output_dir,
        "raw_inputs": raw_inputs,
        "batch_size": batch_size,
        "dtype": dtype,
        "ngpu": ngpu,
        "seed": seed,
        "num_workers": num_workers,
        "log_level": log_level,
        "key_file": key_file,
        "train_config": train_config,
        "model_file": model_file,
        "log_base": log_base,
        "allow_variable_data_keys": allow_variable_data_keys,
        "split_with_space": split_with_space,
        "seg_dict_file": seg_dict_file,
    }
    run_pipeline = inference_modelscope(**builder_args, **kwargs)
    return run_pipeline(data_path_and_name_and_type, raw_inputs)
| | | |
| | | |
| | | def inference_modelscope( |
| | | batch_size: int, |
| | | dtype: str, |
| | | ngpu: int, |
| | | seed: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | key_file: Optional[str], |
| | | train_config: Optional[str], |
| | | model_file: Optional[str], |
| | | log_base: Optional[float] = 10, |
| | | allow_variable_data_keys: bool = False, |
| | | split_with_space: Optional[bool] = False, |
| | | seg_dict_file: Optional[str] = None, |
| | | output_dir: Optional[str] = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if ngpu >= 1 and torch.cuda.is_available(): |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build Model |
| | | model, train_args = LMTask.build_model_from_file( |
| | | train_config, model_file, device) |
| | | wrapped_model = ForwardAdaptor(model, "nll") |
| | | wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval() |
| | | logging.info(f"Model:\n{model}") |
| | | |
| | | preprocessor = LMPreprocessor( |
| | | train=False, |
| | | token_type=train_args.token_type, |
| | | token_list=train_args.token_list, |
| | | bpemodel=train_args.bpemodel, |
| | | text_cleaner=train_args.cleaner, |
| | | g2p_type=train_args.g2p, |
| | | text_name="text", |
| | | non_linguistic_symbols=train_args.non_linguistic_symbols, |
| | | split_with_space=split_with_space, |
| | | seg_dict_file=seg_dict_file |
| | | ) |
| | | |
| | | def _forward( |
| | | data_path_and_name_and_type, |
| | | raw_inputs: Union[List[Any], bytes, str] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | param_dict: dict = None, |
| | | ): |
| | | results = [] |
| | | if output_dir_v2 is not None: |
| | | writer = DatadirWriter(output_dir_v2) |
| | | else: |
| | | writer = None |
| | | |
| | | if raw_inputs != None: |
| | | line = raw_inputs.strip() |
| | | key = "lm demo" |
| | | if line=="": |
| | | item = {'key': key, 'value': ""} |
| | | results.append(item) |
| | | return results |
| | | batch = {} |
| | | batch['text'] = line |
| | | if preprocessor != None: |
| | | batch = preprocessor(key, batch) |
| | | |
| | | # Force data-precision |
| | | for name in batch: |
| | | value = batch[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f"All values must be converted to np.ndarray object " |
| | | f'by preprocessing, but "{name}" is still {type(value)}.' |
| | | ) |
| | | # Cast to desired type |
| | | if value.dtype.kind == "f": |
| | | value = value.astype("float32") |
| | | elif value.dtype.kind == "i": |
| | | value = value.astype("long") |
| | | else: |
| | | raise NotImplementedError(f"Not supported dtype: {value.dtype}") |
| | | batch[name] = value |
| | | |
| | | batch["text_lengths"] = torch.from_numpy( |
| | | np.array([len(batch["text"])], dtype='int32')) |
| | | batch["text"] = np.expand_dims(batch["text"], axis=0) |
| | | |
| | | with torch.no_grad(): |
| | | batch = to_device(batch, device) |
| | | if ngpu <= 1: |
| | | nll, lengths = wrapped_model(**batch) |
| | | else: |
| | | nll, lengths = data_parallel( |
| | | wrapped_model, (), range(ngpu), module_kwargs=batch |
| | | ) |
| | | ## compute ppl |
| | | ppl_out_batch = "" |
| | | ids2tokens = preprocessor.token_id_converter.ids2tokens |
| | | for sent_ids, sent_nll in zip(batch['text'], nll): |
| | | pre_word = "<s>" |
| | | cur_word = None |
| | | sent_lst = ids2tokens(sent_ids) + ['</s>'] |
| | | ppl_out = " ".join(sent_lst) + "\n" |
| | | for word, word_nll in zip(sent_lst, sent_nll): |
| | | cur_word = word |
| | | word_nll = -word_nll.cpu() |
| | | if log_base is None: |
| | | word_prob = np.exp(word_nll) |
| | | else: |
| | | word_prob = log_base ** (word_nll / np.log(log_base)) |
| | | ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( |
| | | cur=cur_word, |
| | | pre=pre_word, |
| | | prob=round(word_prob.item(), 8), |
| | | word_nll=round(word_nll.item(), 8) |
| | | ) |
| | | pre_word = cur_word |
| | | |
| | | sent_nll_mean = sent_nll.mean().cpu().numpy() |
| | | sent_nll_sum = sent_nll.sum().cpu().numpy() |
| | | if log_base is None: |
| | | sent_ppl = np.exp(sent_nll_mean) |
| | | else: |
| | | sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) |
| | | ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( |
| | | sent_nll=round(-sent_nll_sum.item(), 4), |
| | | sent_ppl=round(sent_ppl.item(), 4) |
| | | ) |
| | | ppl_out_batch += ppl_out |
| | | item = {'key': key, 'value': ppl_out} |
| | | if writer is not None: |
| | | writer["ppl"][key+":\n"] = ppl_out |
| | | results.append(item) |
| | | |
| | | return results |
| | | |
| | | # 3. Build data-iterator |
| | | loader = LMTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=preprocessor, |
| | | collate_fn=LMTask.build_collate_fn(train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | # 4. Start for-loop |
| | | total_nll = 0.0 |
| | | total_ntokens = 0 |
| | | ppl_out_all = "" |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | |
| | | ppl_out_batch = "" |
| | | with torch.no_grad(): |
| | | batch = to_device(batch, device) |
| | | if ngpu <= 1: |
| | | # NOTE(kamo): data_parallel also should work with ngpu=1, |
| | | # but for debuggability it's better to keep this block. |
| | | nll, lengths = wrapped_model(**batch) |
| | | else: |
| | | nll, lengths = data_parallel( |
| | | wrapped_model, (), range(ngpu), module_kwargs=batch |
| | | ) |
| | | ## print ppl |
| | | ids2tokens = preprocessor.token_id_converter.ids2tokens |
| | | for key, sent_ids, sent_nll in zip(keys, batch['text'], nll): |
| | | pre_word = "<s>" |
| | | cur_word = None |
| | | sent_lst = ids2tokens(sent_ids) + ['</s>'] |
| | | ppl_out = " ".join(sent_lst) + "\n" |
| | | for word, word_nll in zip(sent_lst, sent_nll): |
| | | cur_word = word |
| | | word_nll = -word_nll.cpu() |
| | | if log_base is None: |
| | | word_prob = np.exp(word_nll) |
| | | else: |
| | | word_prob = log_base ** (word_nll / np.log(log_base)) |
| | | ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format( |
| | | cur=cur_word, |
| | | pre=pre_word, |
| | | prob=round(word_prob.item(), 8), |
| | | word_nll=round(word_nll.item(), 8) |
| | | ) |
| | | pre_word = cur_word |
| | | |
| | | sent_nll_mean = sent_nll.mean().cpu().numpy() |
| | | sent_nll_sum = sent_nll.sum().cpu().numpy() |
| | | if log_base is None: |
| | | sent_ppl = np.exp(sent_nll_mean) |
| | | else: |
| | | sent_ppl = log_base ** (sent_nll_mean / np.log(log_base)) |
| | | ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format( |
| | | sent_nll=round(-sent_nll_sum.item(), 4), |
| | | sent_ppl=round(sent_ppl.item(), 4) |
| | | ) |
| | | ppl_out_batch += ppl_out |
| | | utt2nll = round(-sent_nll_sum.item(), 5) |
| | | item = {'key': key, 'value': ppl_out} |
| | | if writer is not None: |
| | | writer["ppl"][key+":\n"] = ppl_out |
| | | writer["utt2nll"][key] = str(utt2nll) |
| | | results.append(item) |
| | | |
| | | ppl_out_all += ppl_out_batch |
| | | |
| | | assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths)) |
| | | # nll: (B, L) -> (B,) |
| | | nll = nll.detach().cpu().numpy().sum(1) |
| | | # lengths: (B,) |
| | | lengths = lengths.detach().cpu().numpy() |
| | | total_nll += nll.sum() |
| | | total_ntokens += lengths.sum() |
| | | |
| | | if log_base is None: |
| | | ppl = np.exp(total_nll / total_ntokens) |
| | | else: |
| | | ppl = log_base ** (total_nll / total_ntokens / np.log(log_base)) |
| | | |
| | | avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format( |
| | | total_nll=round(-total_nll.item(), 4), |
| | | total_ppl=round(ppl.item(), 4) |
| | | ) |
| | | item = {'key': 'AVG PPL', 'value': avg_ppl} |
| | | ppl_out_all += avg_ppl |
| | | if writer is not None: |
| | | writer["ppl"]["AVG PPL : "] = avg_ppl |
| | | results.append(item) |
| | | |
| | | return results |
| | | |
| | | return _forward |
| | | |
| | | |
def get_parser():
    """Build the command-line parser for the perplexity calculator.

    Options are registered in three sections: general runtime options,
    "Input data related" and "The model configuration related".

    Returns:
        A ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="Calc perplexity",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # General runtime options.
    add(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    add("--output_dir", type=str, required=False)
    add("--ngpu", type=int, default=0, help="The number of gpus. 0 indicates CPU mode")
    add("--seed", type=int, default=0, help="Random seed")
    add(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    add(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    add("--batch_size", type=int, default=1, help="The batch size for inference")
    add(
        "--log_base",
        type=float_or_none,
        default=10,
        help="The base of logarithm for Perplexity. If None, napier's constant is used.",
        required=False,
    )

    # Where the text to score comes from.
    data_group = parser.add_argument_group("Input data related")
    data_group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False,
    )
    data_group.add_argument("--raw_inputs", type=str, required=False)
    data_group.add_argument("--key_file", type=str_or_none)
    data_group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    data_group.add_argument("--split_with_space", type=str2bool, default=False)
    data_group.add_argument("--seg_dict_file", type=str_or_none)

    # Which trained LM to load.
    model_group = parser.add_argument_group("The model configuration related")
    model_group.add_argument("--train_config", type=str)
    model_group.add_argument("--model_file", type=str)

    return parser
| | | |
| | | |
def main(cmd=None):
    """Parse ``cmd`` (``sys.argv`` when None) and run perplexity inference.

    The value of ``get_commandline_args()`` is printed to stderr before
    parsing; every parsed option is then forwarded to ``inference`` as a
    keyword argument.
    """
    print(get_commandline_args(), file=sys.stderr)
    options = get_parser().parse_args(cmd)
    inference(**vars(options))


if __name__ == "__main__":
    main()
| | | |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from typing import Union, Dict, Any |
| | | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.types import float_or_none |
| | | |
| | | |
def get_parser():
    """Build the command-line parser for the LM inference launcher.

    Besides the perplexity options, the launcher takes ``--gpuid_list`` and
    ``--njob``, which ``main`` uses to pick this job's GPU, and ``--mode``,
    which selects the backend in ``inference_launch``.

    Returns:
        A ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="Calc perplexity",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory. When --ngpu > 0, the text after its last '.' "
        "is parsed as the 1-based job id",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        required=True,
        help="Comma-separated GPU ids; job <id> uses entry (id - 1) // njob",
    )
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    # Previously mislabelled "Random seed" (copy-paste from --seed).
    parser.add_argument(
        "--njob", type=int, default=1, help="The number of jobs for each GPU"
    )
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument(
        "--log_base",
        type=float_or_none,
        default=10,
        help="The base of logarithm for Perplexity. "
        "If None, napier's constant is used.",
        required=False
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False
    )
    group.add_argument(
        "--raw_inputs",
        type=str,
        required=False
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group.add_argument("--split_with_space", type=str2bool, default=False)
    group.add_argument("--seg_dict_file", type=str_or_none)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--train_config", type=str)
    group.add_argument("--model_file", type=str)
    group.add_argument(
        "--mode",
        type=str,
        default="lm",
        help="Inference mode passed to inference_launch",
    )
    return parser
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch LM inference to the backend selected by ``mode``.

    Args:
        mode: Inference mode. ``"transformer"`` runs
            ``funasr.bin.lm_inference.inference_modelscope``. ``"lm"`` (the
            default of ``--mode`` in this module's parser) is accepted as an
            alias for it, so a CLI run without ``--mode`` no longer does
            nothing.
        **kwargs: Forwarded unchanged to the selected backend.

    Returns:
        The backend's return value, or ``None`` when ``mode`` is unknown.
    """
    if mode in ("transformer", "lm"):
        # Imported lazily so the launcher can be loaded without the backend.
        from funasr.bin.lm_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    # An unknown mode means nothing was computed; report it at error level
    # instead of INFO so a misconfigured run is not mistaken for success.
    logging.error("Unknown decoding mode: {}".format(mode))
    return None
| | | |
| | | |
def main(cmd=None):
    """Command-line entry point for one LM inference job.

    Parses ``cmd`` (``sys.argv`` when None) and configures logging. When
    ``--ngpu > 0``, it restricts CUDA to the GPU assigned to this job. It
    then delegates to ``inference_launch``.
    """
    print(get_commandline_args(), file=sys.stderr)
    args = get_parser().parse_args(cmd)
    # vars() returns args.__dict__ itself, so popping here also drops the attribute.
    kwargs = vars(args)
    kwargs.pop("config", None)

    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    if args.ngpu > 0:
        # The job id is the numeric suffix of output_dir; every `njob`
        # consecutive job ids share one entry of gpuid_list.
        job_index = int(args.output_dir.split(".")[-1])
        gpu_ids = args.gpuid_list.split(",")
        selected_gpu = gpu_ids[(job_index - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu

    # GPU-assignment options are consumed here and not forwarded.
    for launcher_only_key in ("gpuid_list", "njob"):
        kwargs.pop(launcher_only_key, None)
    inference_launch(**kwargs)


if __name__ == "__main__":
    main()
| | | |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | import os |
| | | |
| | | from funasr.tasks.lm import LMTask |
| | | |
| | | |
def get_parser():
    """Return the ``LMTask`` argument parser extended with ``--gpu_id``."""
    parser = LMTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    return parser


def parse_args(cmd=None):
    """Parse LM-training arguments from ``cmd`` (``sys.argv`` when None)."""
    return get_parser().parse_args(cmd)


def main(args=None, cmd=None):
    """LM training.

    Either pre-parsed ``args`` or a raw ``cmd`` list may be given; both are
    forwarded to ``LMTask.main``.

    Example:

        % python lm_train.py asr --print_config --optim adadelta
        % python lm_train.py --config conf/train_asr.yaml
    """
    LMTask.main(args=args, cmd=cmd)


if __name__ == "__main__":
    # A single entry point: the rendered diff also contained the old
    # `main()` guard, which would have started training twice.
    args = parse_args()

    # setup local gpu_id
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    # DDP settings
    args.distributed = args.ngpu > 1
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.num_worker_count != 1:
        raise ValueError(
            f"num_worker_count must be 1, got {args.num_worker_count}"
        )

    # re-compute batch size: when dataset type is small, the configured
    # values are multiplied by ngpu (presumably so they stay per-GPU --
    # TODO confirm against the data loader).
    if args.dataset_type == "small" and args.ngpu != 0:
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

    main(args=args)
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | from collections import Counter |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | from typing import List |
| | | from typing import Optional |
| | | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.cleaner import TextCleaner |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
| | | def field2slice(field: Optional[str]) -> slice: |
| | | """Convert field string to slice |
| | | |
| | | Note that field string accepts 1-based integer. |
| | | |
| | | Examples: |
| | | >>> field2slice("1-") |
| | | slice(0, None, None) |
| | | >>> field2slice("1-3") |
| | | slice(0, 3, None) |
| | | >>> field2slice("-3") |
| | | slice(None, 3, None) |
| | | """ |
| | | field = field.strip() |
| | | try: |
| | | if "-" in field: |
| | | # e.g. "2-" or "2-5" or "-7" |
| | | s1, s2 = field.split("-", maxsplit=1) |
| | | if s1.strip() == "": |
| | | s1 = None |
| | | else: |
| | | s1 = int(s1) |
| | | if s1 == 0: |
| | | raise ValueError("1-based string") |
| | | if s2.strip() == "": |
| | | s2 = None |
| | | else: |
| | | s2 = int(s2) |
| | | else: |
| | | # e.g. "2" |
| | | s1 = int(field) |
| | | s2 = s1 + 1 |
| | | if s1 == 0: |
| | | raise ValueError("must be 1 or more value") |
| | | except ValueError: |
| | | raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}") |
| | | |
| | | if s1 is None: |
| | | slic = slice(None, s2) |
| | | else: |
| | | # -1 because of 1-based integer following "cut" command |
| | | # e.g "1-3" -> slice(0, 3) |
| | | slic = slice(s1 - 1, s2) |
| | | return slic |
| | | |
| | | |
def _parse_add_symbol(symbol_and_id: str):
    """Split an ``--add_symbol`` value such as ``"<blank>:0"`` into ``(symbol, idx)``.

    Raises:
        RuntimeError: if the value is not exactly ``<symbol>:<int>``.
    """
    try:
        symbol, idx = symbol_and_id.split(":")
        idx = int(idx)
    except ValueError as e:
        raise RuntimeError(
            f"Format error: e.g. '<blank>:0': {symbol_and_id}"
        ) from e
    return symbol.strip(), idx


def _write_vocabulary(counter, add_symbol, cutoff, vocabulary_size, fout):
    """Write the vocabulary from ``counter`` to ``fout``, one token per line.

    Tokens are ordered by descending count. Tokens whose count is not above
    ``cutoff`` are dropped, and the list is truncated so that, together with
    the ``add_symbol`` entries, it has at most ``vocabulary_size`` lines
    (when ``vocabulary_size > 0``). Each ``add_symbol`` entry is inserted at
    its requested index; negative indices count from the end.
    """
    symbols = [_parse_add_symbol(s) for s in add_symbol]

    # Reserved symbols are placed explicitly below, so drop any occurrence
    # counted from the text to avoid listing them twice.
    for symbol, _ in symbols:
        counter.pop(symbol, None)

    # Sort by the number of occurrences in descending order
    # and filter lower frequency words than cutoff value.
    words_and_counts = [
        (w, c) for w, c in sorted(counter.items(), key=lambda x: -x[1]) if c > cutoff
    ]
    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # e.g. idx=0 -> insert as the first symbol, idx=-1 -> append as the last
    for symbol, idx in symbols:
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    for w, _ in words_and_counts:
        fout.write(w + "\n")

    total_count = sum(counter.values())
    invocab_count = sum(c for _, c in words_and_counts if c is not None)
    if total_count > 0:
        logging.info(
            f"OOV rate = {(total_count - invocab_count) / total_count * 100} %"
        )
    else:
        # Empty input: the rate is undefined (previously ZeroDivisionError).
        logging.info("OOV rate = N/A (no tokens counted)")


def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
    """Tokenize text line by line, or build a token vocabulary from it.

    Args:
        input: Input text path, or ``"-"`` for stdin.
        output: Output path (parent dirs are created), or ``"-"`` for stdout.
        field: Optional 1-based column spec (see ``field2slice``) selecting
            which ``delimiter``-separated columns of each line to tokenize.
        write_vocabulary: If False, write each line's tokens space-joined.
            If True, count tokens and write a vocabulary instead; see
            ``cutoff``, ``vocabulary_size`` and ``add_symbol``.
        Other arguments configure ``TextCleaner`` and ``build_tokenizer``.

    Files opened here are closed on return; stdin/stdout are left open.
    """
    assert check_argument_types()
    import contextlib

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    with contextlib.ExitStack() as stack:
        if input == "-":
            fin = sys.stdin
        else:
            fin = stack.enter_context(Path(input).open("r", encoding="utf-8"))
        if output == "-":
            fout = sys.stdout
        else:
            p = Path(output)
            p.parent.mkdir(parents=True, exist_ok=True)
            fout = stack.enter_context(p.open("w", encoding="utf-8"))

        text_cleaner = TextCleaner(cleaner)
        tokenizer = build_tokenizer(
            token_type=token_type,
            bpemodel=bpemodel,
            delimiter=delimiter,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            remove_non_linguistic_symbols=remove_non_linguistic_symbols,
            g2p_type=g2p,
        )

        counter = Counter()
        if field is not None:
            field = field2slice(field)

        for line in fin:
            line = line.rstrip()
            if field is not None:
                # e.g. field="2-": "uttidA hello world!!" -> "hello world!!"
                columns = line.split(delimiter)[field]
                line = (" " if delimiter is None else delimiter).join(columns)

            tokens = tokenizer.text2tokens(text_cleaner(line))
            if write_vocabulary:
                counter.update(tokens)
            else:
                fout.write(" ".join(tokens) + "\n")

        if write_vocabulary:
            _write_vocabulary(counter, add_symbol, cutoff, vocabulary_size, fout)
| | | |
| | | |
def get_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the text tokenization CLI.

    General tokenization options are registered first, followed by a
    "write_vocabulary mode related" group.
    """
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg = parser.add_argument

    arg(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    arg("--input", "-i", required=True, help="Input text. - indicates sys.stdin")
    arg("--output", "-o", required=True, help="Output text. - indicates sys.stdout")
    arg(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g 2-",
    )
    arg(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    arg("--delimiter", "-d", default=None, help="The delimiter")
    arg("--space_symbol", default="<space>", help="The space symbol")
    arg("--bpemodel", default=None, help="The bpemodel file path")
    arg(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    arg(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    arg(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    arg(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    vocab = parser.add_argument_group("write_vocabulary mode related")
    vocab.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    vocab.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    vocab.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    vocab.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )
    return parser
| | | |
| | | |
def main(cmd=None):
    """CLI entry point: print ``get_commandline_args()`` to stderr, then tokenize.

    Every option parsed from ``cmd`` (``sys.argv`` when None) is passed to
    ``tokenize`` as a keyword argument.
    """
    print(get_commandline_args(), file=sys.stderr)
    tokenize(**vars(get_parser().parse_args(cmd)))


if __name__ == "__main__":
    main()
| | |
| | | continue |
| | | return out_txt.strip().split() |
| | | |
def seg_tokenize_wo_pattern(txt, seg_dict):
    """Replace each word of ``txt`` with its segmentation from ``seg_dict``.

    Args:
        txt: Iterable of words (a ``str`` is iterated character by character).
        seg_dict: Mapping from word to a string of sub-tokens; the value is
            split on whitespace.

    Returns:
        A flat list of sub-tokens. Words missing from ``seg_dict`` become
        ``"<unk>"``.
    """
    # Collect tokens in a list instead of growing a string with `+=` and
    # re-splitting it (quadratic); the resulting token sequence is identical.
    tokens = []
    for word in txt:
        if word in seg_dict:
            tokens.extend(seg_dict[word].split())
        else:
            tokens.append("<unk>")
    return tokens
| | | |
| | | |
| | | def framing( |
| | | x, |
| | |
| | | data = self._text_process(data) |
| | | return data |
| | | |
| | | ## FIXME |
class LMPreprocessor(CommonPreprocessor):
    """Text preprocessor for language-model data.

    Construction is delegated unchanged to ``CommonPreprocessor``. Only
    ``_text_process`` is overridden: it turns ``data[text_name]`` into an
    ``np.int64`` array of token ids.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        # Every option is forwarded positionally; this assumes
        # CommonPreprocessor declares its parameters in this same order
        # (TODO confirm against its signature).
        super().__init__(
            train, token_type, token_list, bpemodel, text_cleaner, g2p_type,
            unk_symbol, space_symbol, non_linguistic_symbols, delimiter,
            rir_scp, rir_apply_prob, noise_scp, noise_apply_prob,
            noise_db_range, speech_volume_normalize, speech_name, text_name,
            split_with_space, seg_dict_file,
        )

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Clean, tokenize and id-encode ``data[self.text_name]`` in place.

        ``data`` is returned untouched when the text key is absent or no
        tokenizer is configured.
        """
        has_text = self.text_name in data
        if has_text and self.tokenizer is not None:
            cleaned = self.text_cleaner(data[self.text_name])
            if not self.split_with_space:
                token_seq = self.tokenizer.text2tokens(cleaned)
            else:
                # Whitespace-delimited input, optionally re-segmented via seg_dict.
                token_seq = cleaned.strip().split(" ")
                if self.seg_dict is not None:
                    token_seq = seg_tokenize_wo_pattern(token_seq, self.seg_dict)
            ids = self.token_id_converter.tokens2ids(token_seq)
            data[self.text_name] = np.array(ids, dtype=np.int64)
        assert check_return_type(data)
        return data
| | | |
| | | |
| | | class CommonPreprocessor_multi(AbsPreprocessor): |
| | | def __init__( |
| | |
| | | |
| | | # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>' |
| | | # text: (Batch, Length) -> x, y: (Batch, Length + 1) |
| | | x = F.pad(text, [1, 0], "constant", self.eos) |
| | | x = F.pad(text, [1, 0], "constant", self.sos) |
| | | t = F.pad(text, [0, 1], "constant", self.ignore_id) |
| | | for i, l in enumerate(text_lengths): |
| | | t[i, l] = self.sos |
| | | t[i, l] = self.eos |
| | | x_lengths = text_lengths + 1 |
| | | |
| | | # 2. Forward Language model |
| New file |
| | |
| | | from typing import List |
| | | from typing import Tuple |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | import numpy as np |
| | | |
| | | from funasr.modules.streaming_utils import utils as myutils |
| | | from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM |
| | | from funasr.modules.repeat import repeat |
| | | from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder |
| | | |
| | | |
class ContextualDecoderLayer(nn.Module):
    """SANM-style decoder layer that also exposes intermediate outputs.

    ``forward`` applies feed-forward, then self-attention with a residual,
    then cross-attention with a residual. It returns the final output
    together with the self-attention and cross-attention activations.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # norm2 / norm3 exist only for the attentions that were provided.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # NOTE(review): concat_linear1/2 are never used in forward(); they are
        # left in place, presumably so existing checkpoints' state_dict keys
        # still load -- confirm before removing.
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
        """Run the layer.

        Args:
            tgt: Decoder input, or a tuple whose first element is the input.
            tgt_mask: Mask passed to self-attention and returned unchanged.
            memory: Keys/values for cross-attention.
            memory_mask: Mask for cross-attention.
            cache: Self-attention cache; forced to None in training mode.

        Returns:
            ``(x, tgt_mask, x_self_attn, x_src_attn)``. ``x_self_attn`` is the
            output after the self-attention residual. ``x_src_attn`` is the raw
            cross-attention output before its residual, or ``None`` when the
            layer has no ``src_attn``.
        """
        # tgt = self.dropout(tgt)
        if isinstance(tgt, tuple):
            tgt, _ = tgt
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        # Guarded like __init__: norm2 only exists when self_attn was given.
        if self.self_attn is not None:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)
        x_self_attn = x

        x_src_attn = None
        # Guarded like __init__: norm3 only exists when src_attn was given.
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = self.src_attn(x, memory, memory_mask)
            x_src_attn = x
            x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn
| | | |
| | | |
class ContexutalBiasDecoder(nn.Module):
    """A single optional cross-attention step (optional pre-norm, then dropout).

    When ``src_attn`` is None, the layer passes ``tgt`` through untouched.
    """

    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        # The pre-attention norm is only needed when there is an attention.
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Attend from ``tgt`` over ``memory``.

        Returns:
            ``(out, tgt_mask, memory, memory_mask, cache)``. ``out`` is
            ``tgt`` itself when there is no ``src_attn``, otherwise the
            dropped-out attention output (no residual is added).
        """
        if self.src_attn is None:
            return tgt, tgt_mask, memory, memory_mask, cache
        query = self.norm3(tgt) if self.normalize_before else tgt
        attended = self.src_attn(query, memory, memory_mask)
        return self.dropout(attended), tgt_mask, memory, memory_mask, cache
| | | |
| | | |
| | | class ContextualParaformerDecoder(ParaformerSANMDecoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2006.01713 |
| | | """ |
def __init__(
    self,
    vocab_size: int,
    encoder_output_size: int,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    self_attention_dropout_rate: float = 0.0,
    src_attention_dropout_rate: float = 0.0,
    input_layer: str = "embed",
    use_output_layer: bool = True,
    pos_enc_class=PositionalEncoding,
    normalize_before: bool = True,
    concat_after: bool = False,
    att_layer_num: int = 6,
    kernel_size: int = 21,
    sanm_shfit: int = 0,
):
    """Build the contextual (hotword-biased) Paraformer decoder.

    Sub-modules:
      * ``decoders``: ``att_layer_num - 1`` SANM layers with FSMN self-attention
        and cross-attention over the encoder memory;
      * ``last_decoder``: a ``ContextualDecoderLayer`` built like the layers above;
      * ``bias_decoder``: cross-attention over the contextual embeddings;
      * ``bias_output``: a kernel-size-1 ``Conv1d`` mapping ``2 * attention_dim``
        channels back to ``attention_dim`` (no bias);
      * ``decoders2``: ``num_blocks - att_layer_num`` SANM layers without
        cross-attention, or ``None`` when that count is <= 0;
      * ``decoders3``: a single feed-forward-only layer.

    Args:
        vocab_size: Number of output tokens (embedding / output projection size).
        encoder_output_size: Encoder feature size, also used as the decoder width.
        attention_heads: Heads of every cross-attention module.
        linear_units: Hidden size of the position-wise feed-forward layers.
        num_blocks: Total number of SANM blocks (see ``decoders2``).
        dropout_rate: Dropout used inside the layers and by ``self.dropout``.
        positional_dropout_rate: Dropout of the positional encoding ("linear" input only).
        self_attention_dropout_rate: Dropout of the FSMN self-attention.
        src_attention_dropout_rate: Dropout of the cross-attention modules.
        input_layer: "embed", "linear" or "none" (no input layer).
        use_output_layer: Whether to add the final Linear projection to ``vocab_size``.
        pos_enc_class: Positional-encoding class ("linear" input only).
        normalize_before: Pre-norm layers plus a final ``after_norm`` when True.
        concat_after: Forwarded to every decoder layer.
        att_layer_num: Number of layers with cross-attention, including ``last_decoder``.
        kernel_size: FSMN kernel size.
        sanm_shfit: FSMN shift for the cross-attention layers; ``None`` means
            ``(kernel_size - 1) // 2``. NOTE(review): with the ``int`` annotation,
            ``check_argument_types`` may reject ``None`` — confirm.

    Raises:
        ValueError: If ``input_layer`` is not "embed", "linear" or "none".
    """
    assert check_argument_types()
    super().__init__(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        dropout_rate=dropout_rate,
        positional_dropout_rate=positional_dropout_rate,
        input_layer=input_layer,
        use_output_layer=use_output_layer,
        pos_enc_class=pos_enc_class,
        normalize_before=normalize_before,
    )

    attention_dim = encoder_output_size
    # "none" must be part of the same if/elif chain; as a separate `if` it fell
    # through to the ValueError below and could never be used.
    if input_layer == "none":
        self.embed = None
    elif input_layer == "embed":
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(vocab_size, attention_dim),
            # pos_enc_class(attention_dim, positional_dropout_rate),
        )
    elif input_layer == "linear":
        self.embed = torch.nn.Sequential(
            torch.nn.Linear(vocab_size, attention_dim),
            torch.nn.LayerNorm(attention_dim),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
            pos_enc_class(attention_dim, positional_dropout_rate),
        )
    else:
        raise ValueError(f"only 'embed', 'linear' or 'none' is supported: {input_layer}")

    self.normalize_before = normalize_before
    if self.normalize_before:
        self.after_norm = LayerNorm(attention_dim)
    if use_output_layer:
        self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
    else:
        self.output_layer = None

    self.att_layer_num = att_layer_num
    self.num_blocks = num_blocks
    if sanm_shfit is None:
        sanm_shfit = (kernel_size - 1) // 2
    # att_layer_num - 1 cross-attention layers; the last one is `last_decoder`.
    self.decoders = repeat(
        att_layer_num - 1,
        lambda lnum: DecoderLayerSANM(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
    self.dropout = nn.Dropout(dropout_rate)
    # Contextual bias branch: attends over the hotword embeddings.
    self.bias_decoder = ContexutalBiasDecoder(
        size=attention_dim,
        src_attn=MultiHeadedAttentionCrossAtt(
            attention_heads, attention_dim, src_attention_dropout_rate
        ),
        dropout_rate=dropout_rate,
        normalize_before=True,
    )
    # Projects concatenated [source-attention, bias] features (2D) back to D.
    self.bias_output = torch.nn.Conv1d(attention_dim * 2, attention_dim, 1, bias=False)
    self.last_decoder = ContextualDecoderLayer(
        attention_dim,
        MultiHeadedAttentionSANMDecoder(
            attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
        ),
        MultiHeadedAttentionCrossAtt(
            attention_heads, attention_dim, src_attention_dropout_rate
        ),
        PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
        dropout_rate,
        normalize_before,
        concat_after,
    )
    if num_blocks - att_layer_num <= 0:
        self.decoders2 = None
    else:
        # Remaining blocks: FSMN self-attention only (no cross-attention), shift 0.
        self.decoders2 = repeat(
            num_blocks - att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                ),
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

    # Final feed-forward-only layer.
    self.decoders3 = repeat(
        1,
        lambda lnum: DecoderLayerSANM(
            attention_dim,
            None,
            None,
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        ),
    )
| | | |
def forward(
    self,
    hs_pad: torch.Tensor,
    hlens: torch.Tensor,
    ys_in_pad: torch.Tensor,
    ys_in_lens: torch.Tensor,
    contextual_info: torch.Tensor,
    return_hidden: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode with contextual (hotword) biasing.

    Args:
        hs_pad: Encoder memory, (batch, maxlen_in, feat).
        hlens: Valid encoder lengths, (batch,).
        ys_in_pad: Decoder inputs; they are fed straight into the SANM stack
            (``self.embed`` is not applied here), so presumably these are
            (batch, maxlen_out, feat) acoustic embeddings — TODO confirm.
        ys_in_lens: Valid decoder input lengths, (batch,).
        contextual_info: Contextual embeddings; dim 1 is the number of
            contextual entries, and every entry is treated as valid.
        return_hidden: When True, skip ``output_layer`` and return hidden states.

    Returns:
        ``(x, olens)``: logits (or hidden states when ``return_hidden`` is True
        or there is no output layer) and ``tgt_mask.sum(1)``.
    """
    tgt_mask = myutils.sequence_mask(ys_in_lens, device=ys_in_pad.device)[:, :, None]
    memory = hs_pad
    memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

    # Shared SANM stack, then the last layer, whose self-/source-attention outputs
    # are kept apart so the contextual branch can be fused in.
    hidden, tgt_mask, memory, memory_mask, _ = self.decoders(
        ys_in_pad, tgt_mask, memory, memory_mask
    )
    _, _, self_attn_out, src_attn_out = self.last_decoder(
        hidden, tgt_mask, memory, memory_mask
    )

    # contextual paraformer related: every utterance sees all contextual entries.
    batch_size = hs_pad.shape[0]
    num_context = contextual_info.shape[1]
    context_lens = torch.Tensor([num_context]).int().repeat(batch_size)
    context_mask = myutils.sequence_mask(context_lens, device=memory.device)[:, None, :]
    bias_out, tgt_mask, _, _, _ = self.bias_decoder(
        self_attn_out, tgt_mask, contextual_info, memory_mask=context_mask
    )

    if self.bias_output is not None:
        fused = torch.cat([src_attn_out, bias_out], dim=2)
        fused = self.bias_output(fused.transpose(1, 2)).transpose(1, 2)  # 2D -> D
        hidden = self_attn_out + self.dropout(fused)

    if self.decoders2 is not None:
        hidden, tgt_mask, memory, memory_mask, _ = self.decoders2(
            hidden, tgt_mask, memory, memory_mask
        )
    hidden, tgt_mask, memory, memory_mask, _ = self.decoders3(
        hidden, tgt_mask, memory, memory_mask
    )

    if self.normalize_before:
        hidden = self.after_norm(hidden)
    olens = tgt_mask.sum(1)
    # Identity check on purpose: only a literal False enables the projection.
    if self.output_layer is not None and return_hidden is False:
        hidden = self.output_layer(hidden)
    return hidden, olens
| | | |
def gen_tf2torch_map_dict(self):
    """Return the TF -> PyTorch parameter-name mapping for this decoder.

    Keys are PyTorch parameter names, where the literal ``layeridx`` is a placeholder
    for the layer number. Each value is a dict with:
      * ``name``: the TF variable name (same placeholder), or a list of alternatives;
      * ``squeeze``: axis to squeeze (or ``None``);
      * ``transpose``: axis permutation (or ``None``).
    List-valued ``name`` entries carry parallel ``squeeze``/``transpose`` lists.
    """
    torch_prefix = self.tf2torch_tensor_name_prefix_torch
    tf_prefix = self.tf2torch_tensor_name_prefix_tf

    def entry(tf_name, squeeze=None, transpose=None):
        return {"name": tf_name, "squeeze": squeeze, "transpose": transpose}

    def kernel(tf_name):
        # TF conv1d kernels (1, in, out) -> torch Linear weights (out, in).
        return entry(tf_name, squeeze=0, transpose=(1, 0))

    def ffn_pairs(torch_base, tf_base):
        # Pre-norm + position-wise feed-forward sub-layer.
        return [
            (f"{torch_base}.norm1.weight", entry(f"{tf_base}/LayerNorm/gamma")),
            (f"{torch_base}.norm1.bias", entry(f"{tf_base}/LayerNorm/beta")),
            (f"{torch_base}.feed_forward.w_1.weight", kernel(f"{tf_base}/conv1d/kernel")),
            (f"{torch_base}.feed_forward.w_1.bias", entry(f"{tf_base}/conv1d/bias")),
            (f"{torch_base}.feed_forward.norm.weight", entry(f"{tf_base}/LayerNorm_1/gamma")),
            (f"{torch_base}.feed_forward.norm.bias", entry(f"{tf_base}/LayerNorm_1/beta")),
            (f"{torch_base}.feed_forward.w_2.weight", kernel(f"{tf_base}/conv1d_1/kernel")),
        ]

    def src_attn_pairs(torch_base, tf_base):
        # Pre-norm + cross-attention (q, fused k/v, output projection).
        return [
            (f"{torch_base}.norm3.weight", entry(f"{tf_base}/LayerNorm/gamma")),
            (f"{torch_base}.norm3.bias", entry(f"{tf_base}/LayerNorm/beta")),
            (f"{torch_base}.src_attn.linear_q.weight", kernel(f"{tf_base}/conv1d/kernel")),
            (f"{torch_base}.src_attn.linear_q.bias", entry(f"{tf_base}/conv1d/bias")),
            (f"{torch_base}.src_attn.linear_k_v.weight", kernel(f"{tf_base}/conv1d_1/kernel")),
            (f"{torch_base}.src_attn.linear_k_v.bias", entry(f"{tf_base}/conv1d_1/bias")),
            (f"{torch_base}.src_attn.linear_out.weight", kernel(f"{tf_base}/conv1d_2/kernel")),
            (f"{torch_base}.src_attn.linear_out.bias", entry(f"{tf_base}/conv1d_2/bias")),
        ]

    decoder_torch = f"{torch_prefix}.decoders.layeridx"
    decoder_tf = f"{tf_prefix}/decoder_fsmn_layer_layeridx"
    memory_block_tf = f"{decoder_tf}/decoder_memory_block"
    clas_tf = f"{tf_prefix}/decoder_fsmn_layer_15"
    output_bias_tf = "seq2seq/2bias" if tf_prefix == "seq2seq/decoder/inputter_1" else "seq2seq/bias"

    pairs = []
    ## decoder: ffn, fsmn memory block, source attention
    pairs += ffn_pairs(decoder_torch, f"{decoder_tf}/decoder_ffn")
    pairs += [
        (f"{decoder_torch}.norm2.weight", entry(f"{memory_block_tf}/LayerNorm/gamma")),
        (f"{decoder_torch}.norm2.bias", entry(f"{memory_block_tf}/LayerNorm/beta")),
        (f"{decoder_torch}.self_attn.fsmn_block.weight",
         entry(f"{memory_block_tf}/depth_conv_w", squeeze=0, transpose=(1, 2, 0))),
    ]
    pairs += src_attn_pairs(decoder_torch, f"{decoder_tf}/multi_head")
    # dnn layer and cif concat ffn share the feed-forward layout
    pairs += ffn_pairs(f"{torch_prefix}.decoders3.layeridx", f"{tf_prefix}/decoder_dnn_layer_layeridx")
    pairs += ffn_pairs(f"{torch_prefix}.embed_concat_ffn.layeridx", f"{tf_prefix}/cif_concat")
    # out norm, in embed, out layer (dense kernel/bias or the alternative names)
    pairs += [
        (f"{torch_prefix}.after_norm.weight", entry(f"{tf_prefix}/LayerNorm/gamma")),
        (f"{torch_prefix}.after_norm.bias", entry(f"{tf_prefix}/LayerNorm/beta")),
        (f"{torch_prefix}.embed.0.weight", entry(f"{tf_prefix}/w_embs")),
        (f"{torch_prefix}.output_layer.weight",
         {"name": [f"{tf_prefix}/dense/kernel", f"{tf_prefix}/w_embs"],
          "squeeze": [None, None],
          "transpose": [(1, 0), None]}),
        (f"{torch_prefix}.output_layer.bias",
         {"name": [f"{tf_prefix}/dense/bias", output_bias_tf],
          "squeeze": [None, None],
          "transpose": [None, None]}),
    ]
    ## clas decoder: contextual source attention + 2D -> D output conv
    pairs += src_attn_pairs(f"{torch_prefix}.bias_decoder", f"{clas_tf}/multi_head_1")
    pairs.append(
        (f"{torch_prefix}.bias_output.weight", entry(f"{clas_tf}/conv1d/kernel", transpose=(2, 1, 0)))
    )
    return dict(pairs)
| | | |
def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Build PyTorch tensors for this decoder from TF checkpoint arrays.

    Only names in ``var_dict_torch`` whose first dotted component equals
    ``self.tf2torch_tensor_name_prefix_torch`` are considered. Each one is looked up
    in ``self.gen_tf2torch_map_dict()``, read from ``var_dict_tf``, optionally
    squeezed/transposed, cast to float32 on CPU and checked for a size match.
    Templated names not present in the map are skipped silently.

    Args:
        var_dict_tf: TF variable name -> array accepted by ``np.squeeze`` /
            ``torch.from_numpy``.
        var_dict_torch: PyTorch parameter name -> tensor (used for names and sizes).

    Returns:
        dict mapping PyTorch parameter names to the converted tensors.
    """
    map_dict = self.gen_tf2torch_map_dict()
    var_dict_torch_update = dict()
    # TF layer indices already claimed by `decoders` / `last_decoder`; used as the
    # offset for `decoders2`.
    decoder_layeridx_sets = set()

    def copy_tensor(name, name_tf, squeeze, transpose, assert_label, log_label):
        # Reshape one TF array, check it against the torch parameter, record it.
        data_tf = var_dict_tf[name_tf]
        if squeeze is not None:
            data_tf = np.squeeze(data_tf, axis=squeeze)
        if transpose is not None:
            data_tf = np.transpose(data_tf, transpose)
        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
            name, assert_label, var_dict_torch[name].size(), data_tf.size())
        var_dict_torch_update[name] = data_tf
        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
            name, data_tf.size(), log_label, var_dict_tf[name_tf].shape))

    def copy_templated(name, name_q, layeridx):
        # Resolve a map entry keyed by `name_q`; substitute `layeridx` unless None.
        if name_q not in map_dict:
            return
        spec = map_dict[name_q]
        name_v = spec["name"]
        name_tf = name_v if layeridx is None else name_v.replace("layeridx", str(layeridx))
        copy_tensor(name, name_tf, spec["squeeze"], spec["transpose"], name_tf, name_v)

    torch_prefix = self.tf2torch_tensor_name_prefix_torch
    for name in sorted(var_dict_torch.keys()):
        names = name.split('.')
        if names[0] != torch_prefix:
            continue
        scope = names[1]

        if scope == "decoders":
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            decoder_layeridx_sets.add(layeridx)
            copy_templated(name, name_q, layeridx)
        elif scope == "last_decoder":
            # The contextual last layer is stored as TF layer 15.
            layeridx = 15
            decoder_layeridx_sets.add(layeridx)
            copy_templated(name, name.replace("last_decoder", "decoders.layeridx"), layeridx)
        elif scope == "decoders2":
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.").replace("decoders2", "decoders")
            layeridx += len(decoder_layeridx_sets)
            if "decoders." in name:
                decoder_layeridx_sets.add(layeridx)
            copy_templated(name, name_q, layeridx)
        elif scope in ("decoders3", "embed_concat_ffn"):
            layeridx = int(names[2])
            name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
            if "decoders." in name:
                decoder_layeridx_sets.add(layeridx)
            copy_templated(name, name_q, layeridx)
        elif scope == "bias_decoder":
            copy_templated(name, name, None)
        elif scope in ("embed", "output_layer", "bias_output"):
            spec = map_dict[name]
            name_tf = spec["name"]
            if isinstance(name_tf, list):
                # Use the first candidate when present, otherwise the second.
                idx = 0 if name_tf[0] in var_dict_tf.keys() else 1
                copy_tensor(name, name_tf[idx], spec["squeeze"][idx], spec["transpose"][idx],
                            name_tf, name_tf[idx])
            else:
                copy_tensor(name, name_tf, spec["squeeze"], spec["transpose"], name_tf, name_tf)
        elif scope == "after_norm":
            # Copied as-is: no squeeze/transpose and no size check.
            name_tf = map_dict[name]["name"]
            data_tf = torch.from_numpy(var_dict_tf[name_tf]).type(torch.float32).to("cpu")
            var_dict_torch_update[name] = data_tf
            logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

    return var_dict_torch_update
| | |
| | | from typing import Union |
| | | |
| | | import torch |
| | | import random |
| | | import numpy as np |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.modules.add_sos_eos import add_sos_eos |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.nets_utils import make_pad_mask, pad_list |
| | | from funasr.modules.nets_utils import th_accuracy |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | |
| | | |
| | | class BiCifParaformer(Paraformer): |
| | | |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | """ |
| | | Paraformer model with an extra cif predictor |
| | | to conduct accurate timestamp prediction |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | |
| | | ) |
| | | assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3" |
| | | |
| | | def _calc_att_loss( |
| | | def _calc_pre2_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | |
| | | if self.predictor_bias == 1: |
| | | _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_pad_lens = ys_pad_lens + self.predictor_bias |
| | | pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, |
| | | ignore_id=self.ignore_id) |
| | | _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id) |
| | | |
| | | # 0. sampler |
| | | decoder_out_1st = None |
| | | if self.sampling_ratio > 0.0: |
| | | if self.step_cur < 2: |
| | | logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, |
| | | pre_acoustic_embeds) |
| | | else: |
| | | if self.step_cur < 2: |
| | | logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds = pre_acoustic_embeds |
| | | # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) |
| | | loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2) |
| | | |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens |
| | | ) |
| | | decoder_out, _ = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if decoder_out_1st is None: |
| | | decoder_out_1st = decoder_out |
| | | # 2. Compute attention loss |
| | | loss_att = self.criterion_att(decoder_out, ys_pad) |
| | | acc_att = th_accuracy( |
| | | decoder_out_1st.view(-1, self.vocab_size), |
| | | ys_pad, |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) |
| | | loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2) |
| | | |
| | | # Compute cer/wer using attention-decoder |
| | | if self.training or self.error_calculator is None: |
| | | cer_att, wer_att = None, None |
| | | else: |
| | | ys_hat = decoder_out_1st.argmax(dim=-1) |
| | | cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 |
| | | return loss_pre2 |
| | | |
| | | def calc_predictor(self, encoder_out, encoder_out_lens): |
| | | |
| | |
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
    """Run the predictor's timestamp path over the encoder output.

    Builds a padding mask from ``encoder_out_lens`` and calls
    ``self.predictor.get_upsample_timestamp(encoder_out, mask, token_num)`` once.
    The previous version left a ``pdb.set_trace()`` breakpoint in place and also
    made an earlier predictor call whose results were immediately overwritten.

    Args:
        encoder_out: Encoder output; dim 1 is treated as the time axis.
        encoder_out_lens: Valid lengths per utterance.
        token_num: Token counts forwarded to the predictor.

    Returns:
        ``(ds_alphas, ds_cif_peak, us_alphas, us_cif_peak)`` as returned by the
        predictor — presumably down-/up-sampled CIF weights and peaks; confirm
        against ``get_upsample_timestamp``.
    """
    encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
        encoder_out.device)
    ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(
        encoder_out, encoder_out_mask, token_num)
    return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
| | | |
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Encode speech and compute the predictor length loss only.

    No decoder pass happens here: the returned loss is ``self._calc_pre2_loss``.

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)

    Returns:
        ``(loss, stats, weight)`` after ``force_gatherable``; ``stats`` holds
        ``"loss_pre2"`` and ``"loss"``.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # All inputs must agree on the batch size.
    assert (
        speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]
    self.step_cur += 1

    # for data-parallel: drop padding beyond the longest item in this shard
    text = text[:, : text_lengths.max()]
    speech = speech[:, :speech_lengths.max()]

    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    loss = self._calc_pre2_loss(encoder_out, encoder_out_lens, text, text_lengths)

    stats = {
        "loss_pre2": loss.detach().cpu(),
        "loss": torch.clone(loss.detach()),
    }

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
| | | |
| | | class ContextualParaformer(Paraformer): |
| | | """ |
| | | Paraformer model with contextual hotword |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | token_list: Union[Tuple[str, ...], List[str]], |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | preencoder: Optional[AbsPreEncoder], |
| | | encoder: AbsEncoder, |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | | blank_id: int = 0, |
| | | sos: int = 1, |
| | | eos: int = 2, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | extract_feats_in_collect_stats: bool = True, |
| | | predictor=None, |
| | | predictor_weight: float = 0.0, |
| | | predictor_bias: int = 0, |
| | | sampling_ratio: float = 0.2, |
| | | min_hw_length: int = 2, |
| | | max_hw_length: int = 4, |
| | | sample_rate: float = 0.6, |
| | | batch_rate: float = 0.5, |
| | | double_rate: float = -1.0, |
| | | target_buffer_length: int = -1, |
| | | inner_dim: int = 256, |
| | | bias_encoder_type: str = 'lstm', |
| | | label_bracket: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | preencoder=preencoder, |
| | | encoder=encoder, |
| | | postencoder=postencoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | ctc_weight=ctc_weight, |
| | | interctc_weight=interctc_weight, |
| | | ignore_id=ignore_id, |
| | | blank_id=blank_id, |
| | | sos=sos, |
| | | eos=eos, |
| | | lsm_weight=lsm_weight, |
| | | length_normalized_loss=length_normalized_loss, |
| | | report_cer=report_cer, |
| | | report_wer=report_wer, |
| | | sym_space=sym_space, |
| | | sym_blank=sym_blank, |
| | | extract_feats_in_collect_stats=extract_feats_in_collect_stats, |
| | | predictor=predictor, |
| | | predictor_weight=predictor_weight, |
| | | predictor_bias=predictor_bias, |
| | | sampling_ratio=sampling_ratio, |
| | | ) |
| | | |
| | | if bias_encoder_type == 'lstm': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0) |
| | | self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim) |
| | | else: |
| | | logging.error("Unsupport bias encoder type") |
| | | |
| | | self.min_hw_length = min_hw_length |
| | | self.max_hw_length = max_hw_length |
| | | self.sample_rate = sample_rate |
| | | self.batch_rate = batch_rate |
| | | self.target_buffer_length = target_buffer_length |
| | | self.double_rate = double_rate |
| | | |
| | | if self.target_buffer_length > 0: |
| | | self.hotword_buffer = None |
| | | self.length_record = [] |
| | | self.current_buffer_length = 0 |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | |
| | | # 2b. Attention decoder branch |
| | | if self.ctc_weight != 1.0: |
| | | loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss( |
| | | loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight |
| | | loss = loss_att + loss_pre * self.predictor_weight |
| | | elif self.ctc_weight == 1.0: |
| | | loss = loss_ctc |
| | | else: |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | |
| | | stats["cer"] = cer_att |
| | | stats["wer"] = wer_att |
| | | stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None |
| | | stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None |
| | | |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | return loss, stats, weight |
| | | |
| | | def _sample_hot_word(self, ys_pad, ys_pad_lens): |
| | | hw_list = [torch.Tensor([0]).long().to(ys_pad.device)] |
| | | hw_lengths = [0] # this length is actually for indice, so -1 |
| | | for i, length in enumerate(ys_pad_lens): |
| | | if length < 2: |
| | | continue |
| | | if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate: |
| | | # sample double hotword |
| | | _max_hw_length = min(self.max_hw_length, length // 2) |
| | | # first hotword |
| | | start1 = random.randint(0, length // 3) |
| | | end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1) |
| | | hw_tokens1 = ys_pad[i][start1:end1 + 1] |
| | | hw_lengths.append(len(hw_tokens1) - 1) |
| | | hw_list.append(hw_tokens1) |
| | | # second hotword |
| | | start2 = random.randint(end1 + 1, length - self.min_hw_length) |
| | | end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1), |
| | | min(length - 1, start2 + self.max_hw_length - 1)) |
| | | hw_tokens2 = ys_pad[i][start2:end2 + 1] |
| | | hw_lengths.append(len(hw_tokens2) - 1) |
| | | hw_list.append(hw_tokens2) |
| | | continue |
| | | if random.random() < self.sample_rate: |
| | | if length == 2: |
| | | hw_tokens = ys_pad[i][:2] |
| | | hw_lengths.append(1) |
| | | hw_list.append(hw_tokens) |
| | | else: |
| | | start = random.randint(0, length - self.min_hw_length) |
| | | end = random.randint(min(length - 1, start + self.min_hw_length - 1), |
| | | min(length - 1, start + self.max_hw_length - 1)) + 1 |
| | | # print(start, end) |
| | | hw_tokens = ys_pad[i][start:end] |
| | | hw_lengths.append(len(hw_tokens) - 1) |
| | | hw_list.append(hw_tokens) |
| | | # padding |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | | hw_embed = self.decoder.embed(hw_list_pad) |
| | | hw_embed, (_, _) = self.bias_encoder(hw_embed) |
| | | _ind = np.arange(0, len(hw_list)).tolist() |
| | | # update self.hotword_buffer, throw a part if oversize |
| | | selected = hw_embed[_ind, hw_lengths] |
| | | if self.target_buffer_length > 0: |
| | | _b = selected.shape[0] |
| | | if self.hotword_buffer is None: |
| | | self.hotword_buffer = selected |
| | | self.length_record.append(selected.shape[0]) |
| | | self.current_buffer_length = _b |
| | | elif self.current_buffer_length + _b < self.target_buffer_length: |
| | | self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0) |
| | | self.current_buffer_length += _b |
| | | selected = self.hotword_buffer |
| | | else: |
| | | self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0) |
| | | random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10 |
| | | self.hotword_buffer = self.hotword_buffer[-1 * random_throw:] |
| | | selected = self.hotword_buffer |
| | | self.current_buffer_length = selected.shape[0] |
| | | return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) |
| | | |
| | | def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info): |
| | | |
| | | tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device) |
| | | ys_pad = ys_pad * tgt_mask[:, :, 0] |
| | | if self.share_embedding: |
| | | ys_pad_embed = self.decoder.output_layer.weight[ys_pad] |
| | | else: |
| | | ys_pad_embed = self.decoder.embed(ys_pad) |
| | | with torch.no_grad(): |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | | ) |
| | | decoder_out, _ = decoder_outs[0], decoder_outs[1] |
| | | pred_tokens = decoder_out.argmax(-1) |
| | | nonpad_positions = ys_pad.ne(self.ignore_id) |
| | | seq_lens = (nonpad_positions).sum(1) |
| | | same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1) |
| | | input_mask = torch.ones_like(nonpad_positions) |
| | | bsz, seq_len = ys_pad.size() |
| | | for li in range(bsz): |
| | | target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long() |
| | | if target_num > 0: |
| | | input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0) |
| | | input_mask = input_mask.eq(1) |
| | | input_mask = input_mask.masked_fill(~nonpad_positions, False) |
| | | input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device) |
| | | |
| | | sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill( |
| | | input_mask_expand_dim, 0) |
| | | return sematic_embeds * tgt_mask, decoder_out * tgt_mask |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | if self.predictor_bias == 1: |
| | | _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_pad_lens = ys_pad_lens + self.predictor_bias |
| | | pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, |
| | | encoder_out_mask, |
| | | ignore_id=self.ignore_id) |
| | | |
| | | # sample hot word |
| | | contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens) |
| | | |
| | | # 0. sampler |
| | | decoder_out_1st = None |
| | | if self.sampling_ratio > 0.0: |
| | | if self.step_cur < 2: |
| | | logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, |
| | | pre_acoustic_embeds, contextual_info) |
| | | else: |
| | | if self.step_cur < 2: |
| | | logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds = pre_acoustic_embeds |
| | | |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | | ) |
| | | decoder_out, _ = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if decoder_out_1st is None: |
| | | decoder_out_1st = decoder_out |
| | | # 2. Compute attention loss |
| | | loss_att = self.criterion_att(decoder_out, ys_pad) |
| | | acc_att = th_accuracy( |
| | | decoder_out_1st.view(-1, self.vocab_size), |
| | | ys_pad, |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) |
| | | |
| | | # Compute cer/wer using attention-decoder |
| | | if self.training or self.error_calculator is None: |
| | | cer_att, wer_att = None, None |
| | | else: |
| | | ys_hat = decoder_out_1st.argmax(dim=-1) |
| | | cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att, loss_pre |
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None): |
| | | if hw_list is None: |
| | | # default hotword list |
| | | hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | | contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1) |
| | | else: |
| | | hw_lengths = [len(i) for i in hw_list] |
| | | hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device) |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True, |
| | | enforce_sorted=False) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | | # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True) |
| | | contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1) |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | | ) |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | | |
| | | def gen_clas_tf2torch_map_dict(self): |
| | | tensor_name_prefix_torch = "bias_encoder" |
| | | tensor_name_prefix_tf = "seq2seq/clas_charrnn" |
| | | |
| | | tensor_name_prefix_torch_emb = "bias_embed" |
| | | tensor_name_prefix_tf_emb = "seq2seq" |
| | | |
| | | map_dict_local = { |
| | | # in lstm |
| | | "{}.weight_ih_l0".format(tensor_name_prefix_torch): |
| | | {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf), |
| | | "squeeze": None, |
| | | "transpose": (1, 0), |
| | | "slice": (0, 512), |
| | | "unit_k": 512, |
| | | }, # (1024, 2048),(2048,512) |
| | | "{}.weight_hh_l0".format(tensor_name_prefix_torch): |
| | | {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf), |
| | | "squeeze": None, |
| | | "transpose": (1, 0), |
| | | "slice": (512, 1024), |
| | | "unit_k": 512, |
| | | }, # (1024, 2048),(2048,512) |
| | | "{}.bias_ih_l0".format(tensor_name_prefix_torch): |
| | | {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf), |
| | | "squeeze": None, |
| | | "transpose": None, |
| | | "scale": 0.5, |
| | | "unit_b": 512, |
| | | }, # (2048,),(2048,) |
| | | "{}.bias_hh_l0".format(tensor_name_prefix_torch): |
| | | {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf), |
| | | "squeeze": None, |
| | | "transpose": None, |
| | | "scale": 0.5, |
| | | "unit_b": 512, |
| | | }, # (2048,),(2048,) |
| | | |
| | | # in embed |
| | | "{}.weight".format(tensor_name_prefix_torch_emb): |
| | | {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb), |
| | | "squeeze": None, |
| | | "transpose": None, |
| | | }, # (4235,256),(4235,256) |
| | | } |
| | | return map_dict_local |
| | | |
| | | def clas_convert_tf2torch(self, |
| | | var_dict_tf, |
| | | var_dict_torch): |
| | | map_dict = self.gen_clas_tf2torch_map_dict() |
| | | var_dict_torch_update = dict() |
| | | for name in sorted(var_dict_torch.keys(), reverse=False): |
| | | names = name.split('.') |
| | | if names[0] == "bias_encoder": |
| | | name_q = name |
| | | if name_q in map_dict.keys(): |
| | | name_v = map_dict[name_q]["name"] |
| | | name_tf = name_v |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name_q].get("unit_k") is not None: |
| | | dim = map_dict[name_q]["unit_k"] |
| | | i = data_tf[:, 0:dim].copy() |
| | | f = data_tf[:, dim:2 * dim].copy() |
| | | o = data_tf[:, 2 * dim:3 * dim].copy() |
| | | g = data_tf[:, 3 * dim:4 * dim].copy() |
| | | data_tf = np.concatenate([i, o, f, g], axis=1) |
| | | if map_dict[name_q].get("unit_b") is not None: |
| | | dim = map_dict[name_q]["unit_b"] |
| | | i = data_tf[0:dim].copy() |
| | | f = data_tf[dim:2 * dim].copy() |
| | | o = data_tf[2 * dim:3 * dim].copy() |
| | | g = data_tf[3 * dim:4 * dim].copy() |
| | | data_tf = np.concatenate([i, o, f, g], axis=0) |
| | | if map_dict[name_q]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"]) |
| | | if map_dict[name_q].get("slice") is not None: |
| | | data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]] |
| | | if map_dict[name_q].get("scale") is not None: |
| | | data_tf = data_tf * map_dict[name_q]["scale"] |
| | | if map_dict[name_q]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v, |
| | | var_dict_tf[name_tf].shape)) |
| | | elif names[0] == "bias_embed": |
| | | name_tf = map_dict[name]["name"] |
| | | data_tf = var_dict_tf[name_tf] |
| | | if map_dict[name]["squeeze"] is not None: |
| | | data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"]) |
| | | if map_dict[name]["transpose"] is not None: |
| | | data_tf = np.transpose(data_tf, map_dict[name]["transpose"]) |
| | | data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu") |
| | | assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf, |
| | | var_dict_torch[ |
| | | name].size(), |
| | | data_tf.size()) |
| | | var_dict_torch_update[name] = data_tf |
| | | logging.info( |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | |
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | | acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
| | | return acoustic_embeds, token_num, alphas, cif_peak, token_num2
|
| | | |
| | | def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
|
| | | target_label_length=None, token_num=None):
|
| | |
|
| | | def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
|
| | | h = hidden
|
| | | b = hidden.shape[0]
|
| | | context = h.transpose(1, 2)
|
| | |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | |
| | | |
| | | if args.dry_run: |
| | | pass |
| | | elif args.collect_stats: |
| | | # Perform on collect_stats mode. This mode has two roles |
| | | # - Derive the length and dimension of all input data |
| | | # - Accumulate feats, square values, and the length for whitening |
| | | |
| | | if args.valid_batch_size is None: |
| | | args.valid_batch_size = args.batch_size |
| | | |
| | | if len(args.train_shape_file) != 0: |
| | | train_key_file = args.train_shape_file[0] |
| | | else: |
| | | train_key_file = None |
| | | if len(args.valid_shape_file) != 0: |
| | | valid_key_file = args.valid_shape_file[0] |
| | | else: |
| | | valid_key_file = None |
| | | |
| | | collect_stats( |
| | | model=model, |
| | | train_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
| | | key_file=train_key_file, |
| | | batch_size=args.batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | valid_iter=cls.build_streaming_iterator( |
| | | data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
| | | key_file=valid_key_file, |
| | | batch_size=args.valid_batch_size, |
| | | dtype=args.train_dtype, |
| | | num_workers=args.num_workers, |
| | | allow_variable_data_keys=args.allow_variable_data_keys, |
| | | ngpu=args.ngpu, |
| | | preprocess_fn=cls.build_preprocess_fn(args, train=False), |
| | | collate_fn=cls.build_collate_fn(args, train=False), |
| | | ), |
| | | output_dir=output_dir, |
| | | ngpu=args.ngpu, |
| | | log_interval=args.log_interval, |
| | | write_collected_feats=args.write_collected_feats, |
| | | ) |
| | | else: |
| | | logging.info("Training args: {}".format(args)) |
| | | # 6. Loads pre-trained model |
| | |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | |
| | | paraformer=Paraformer, |
| | | paraformer_bert=ParaformerBert, |
| | | bicif_paraformer=BiCifParaformer, |
| | | contextual_paraformer=ContextualParaformer, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | default="asr", |
| | |
| | | fsmn_scama_opt=FsmnDecoderSCAMAOpt, |
| | | paraformer_decoder_sanm=ParaformerSANMDecoder, |
| | | paraformer_decoder_san=ParaformerDecoderSAN, |
| | | contextual_paraformer_decoder=ContextualParaformerDecoder, |
| | | ), |
| | | type_check=AbsDecoder, |
| | | default="rnn", |
| | |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # bias_encoder |
| | | var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |
| | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | required = parser.get_default("required") |
| | | required += ["token_list"] |
| | | # required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | |
| | | return sentence, ts_lists, real_word_lists |
| | | else: |
| | | word_lists = abbr_dispose(word_lists) |
| | | real_word_lists = [] |
| | | for ch in word_lists: |
| | | if ch != ' ': |
| | | real_word_lists.append(ch) |
| | | sentence = ''.join(word_lists).strip() |
| | | return sentence |
| | | return sentence, real_word_lists |
| | |
| | | else: |
| | | return time_stamp_list |
| | | |
| | | |
| | | def time_stamp_lfr6_advance(tst: List, text: str): |
| | | # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas |
| | | ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst |
| | | if text.endswith('</s>'): |
| | | text = text[:-4] |
| | | def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None): |
| | | START_END_THRESHOLD = 5 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | | if len(us_alphas.shape) == 3: |
| | | alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only |
| | | else: |
| | | text = text[:-1] |
| | | logging.warning("found text does not end with </s>") |
| | | assert int(ds_alphas.sum() + 1e-4) - 1 == len(text) |
| | | |
| | | alphas, cif_peak = us_alphas, us_cif_peak |
| | | num_frames = cif_peak.shape[0] |
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | # char_list = [i for i in text] |
| | | timestamp_list = [] |
| | | # for bicif model trained with large data, cif2 actually fires when a character starts |
| | | # so treat the frames between two peaks as the duration of the former token |
| | | fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5 |
| | | num_peak = len(fire_place) |
| | | assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |
| | | if fire_place[0] > START_END_THRESHOLD: |
| | | char_list.insert(0, '<sil>') |
| | | timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) |
| | | # tokens timestamp |
| | | for i in range(len(fire_place)-1): |
| | | # the peak is always a little ahead of the start time |
| | | # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | # cut the duration to token and sil of the 0-weight frames last long |
| | | # tail token and end silence |
| | | if num_frames - fire_place[-1] > START_END_THRESHOLD: |
| | | _end = (num_frames + fire_place[-1]) / 2 |
| | | timestamp_list[-1][1] = _end*TIME_RATE |
| | | timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) |
| | | char_list.append("<sil>") |
| | | else: |
| | | timestamp_list[-1][1] = num_frames*TIME_RATE |
| | | if begin_time: # add offset time in model with vad |
| | | for i in range(len(timestamp_list)): |
| | | timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0 |
| | | timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0 |
| | | res_txt = "" |
| | | for char, timestamp in zip(char_list, timestamp_list): |
| | | res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1]) |
| | | res = [] |
| | | for char, timestamp in zip(char_list, timestamp_list): |
| | | if char != '<sil>': |
| | | res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) |
| | | return res |
| | | |