From 59a791121fccd3c9ca177c4f6d33105a82d23ef3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 二月 2023 15:30:49 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/lm_inference_launch.py | 130 ++
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md | 25
funasr/bin/asr_inference_paraformer.py | 61 +
funasr/bin/asr_inference_uniasr_vad.py | 2
funasr/bin/lm_calc_perplexity.py | 3
funasr/datasets/preprocessor.py | 73 +
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py | 87 +
funasr/bin/lm_inference.py | 406 ++++++++
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py | 52 +
funasr/bin/asr_inference_paraformer_vad_punc.py | 38
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md | 23
funasr/bin/lm_train.py | 50
funasr/models/e2e_asr_paraformer.py | 496 +++++++++-
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py | 37
funasr/models/decoder/contextual_decoder.py | 776 ++++++++++++++++
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md | 75 +
funasr/bin/asr_inference_paraformer_vad.py | 19
funasr/tasks/lm.py | 2
funasr/tasks/asr.py | 8
funasr/tasks/abs_task.py | 47
funasr/utils/postprocess_utils.py | 6
funasr/lm/espnet_model.py | 4
funasr/bin/asr_inference_uniasr.py | 2
funasr/bin/asr_inference.py | 2
funasr/utils/timestamp_tools.py | 57
funasr/models/predictor/cif.py | 5
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md | 53 +
funasr/bin/tokenize_text.py | 283 +++++
funasr/bin/asr_inference_paraformer_timestamp.py | 2
egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py | 17
30 files changed, 2,737 insertions(+), 104 deletions(-)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
new file mode 100644
index 0000000..c2e4354
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained Paraformer-large Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - <strong>batch_bins:</strong> # batch size. If dataset_type is `small`, `batch_bins` indicates the feature frames; if dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - <strong>max_epoch:</strong> # number of training epoch
+ - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+ python finetune.py
+```
+
+### Inference
+
+Or you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+ - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>output_dir:</strong> # result dir
+ - <strong>ngpu:</strong> # the number of GPUs for decoding
+ - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
new file mode 100644
index 0000000..a5f1ee4
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
@@ -0,0 +1,37 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
+def modelscope_finetune(params):
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
+ # dataset split ["train", "validation"]
+ ds_dict = MsDataset.load(params.data_path)
+ kwargs = dict(
+ model=params.model,
+ data_dir=ds_dict,
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
+ trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ params = modelscope_args(model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
+ data_path="./data")
+ params.output_dir = "./checkpoint"
+ params.data_path = "./example_data/"
+ params.dataset_type = "small"
+ params.batch_bins = 16000
+ params.max_epoch = 50
+ params.lr = 0.00002
+
+ modelscope_finetune(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
new file mode 100644
index 0000000..c016c19
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+from multiprocessing import Pool
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+ output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+ gpu_id = (int(idx) - 1) // njob
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+ inference_pipline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
+ output_dir=output_dir_job,
+ )
+ audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+ inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+ # prepare for multi-GPU decoding
+ ngpu = params["ngpu"]
+ njob = params["njob"]
+ output_dir = params["output_dir"]
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ os.mkdir(output_dir)
+ split_dir = os.path.join(output_dir, "split")
+ os.mkdir(split_dir)
+ nj = ngpu * njob
+ wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(modelscope_infer_core,
+ args=(output_dir, split_dir, njob, str(i + 1)))
+ p.close()
+ p.join()
+
+ # combine decoding results
+ best_recog_path = os.path.join(output_dir, "1best_recog")
+ os.mkdir(best_recog_path)
+ files = ["text", "token", "score"]
+ for file in files:
+ with open(os.path.join(best_recog_path, file), "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+ with open(job_file) as f_job:
+ lines = f_job.readlines()
+ f.writelines(lines)
+
+ # If text exists, compute CER
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(best_recog_path, "token")
+ compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
+
+
+if __name__ == "__main__":
+ params = {}
+ params["data_dir"] = "./data/test"
+ params["output_dir"] = "./results"
+ params["ngpu"] = 2
+ params["njob"] = 5
+ modelscope_infer(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
new file mode 100644
index 0000000..56c282c
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
@@ -0,0 +1,52 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+ # prepare for decoding
+ pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+ for file_name in params["required_files"]:
+ if file_name == "configuration.json":
+ with open(os.path.join(pretrained_model_path, file_name)) as f:
+ config_dict = json.load(f)
+ config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+ with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+ json.dump(config_dict, f, indent=4, separators=(',', ': '))
+ else:
+ shutil.copy(os.path.join(pretrained_model_path, file_name),
+ os.path.join(params["output_dir"], file_name))
+ decoding_path = os.path.join(params["output_dir"], "decode_results")
+ if os.path.exists(decoding_path):
+ shutil.rmtree(decoding_path)
+ os.mkdir(decoding_path)
+
+ # decoding
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=params["output_dir"],
+ output_dir=decoding_path,
+ )
+ audio_in = os.path.join(params["data_dir"], "wav.scp")
+ inference_pipeline(audio_in=audio_in)
+
+ # compute CER if GT text is set
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
+
+
+if __name__ == '__main__':
+ params = {}
+ params["modelscope_model_name"] = "damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k"
+ params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
+ params["output_dir"] = "./checkpoint"
+ params["data_dir"] = "./data/test"
+ params["decoding_model_name"] = "valid.cer_ctc.ave.pth"
+ modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..5eeae37
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,23 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset CER(%) | base model|finetune model |
+|:--------------:|:---------:|:-------------:|
+| dev | 1.75 |1.62 |
+| test | 1.95 |1.78 |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..71d9fee
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,25 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-2
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | base model|finetune model|
+|:------------:|:---------:|:------------:|
+| dev_ios | 2.80 |2.60 |
+| test_android | 3.13 |2.84 |
+| test_ios | 2.85 |2.82 |
+| test_mic | 3.06 |2.88 |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..ec95be3
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,75 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Tue Nov 22 18:48:39 CST 2022`
+- python version: `3.7.12`
+- FunASR version: `0.1.0`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:---------:|:-----:|
+| dev | 1.75 |
+| test | 1.95 |
+
+## AISHELL-2
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:------------:|:-----:|
+| dev_ios | 2.80 |
+| test_android | 3.13 |
+| test_ios | 2.85 |
+| test_mic | 3.06 |
+
+## Wenetspeech
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:---------:|:-----:|
+| dev | 3.57 |
+| test | 6.97 |
+| test_net | 6.74 |
+
+## SpeechIO TIOBE
+- Decode config 1:
+ - Decode without CTC
+ - Decode without LM
+ - With text norm
+- Decode config 2:
+ - Decode without CTC
+ - Decode with Transformer-LM
+ - LM weight: 0.15
+ - With text norm
+
+| testset | w/o LM | w/ LM |
+|:------------------:|:----:|:----:|
+|SPEECHIO_ASR_ZH00001| 0.49 | 0.35 |
+|SPEECHIO_ASR_ZH00002| 3.23 | 2.86 |
+|SPEECHIO_ASR_ZH00003| 1.13 | 0.80 |
+|SPEECHIO_ASR_ZH00004| 1.33 | 1.10 |
+|SPEECHIO_ASR_ZH00005| 1.41 | 1.18 |
+|SPEECHIO_ASR_ZH00006| 5.25 | 4.85 |
+|SPEECHIO_ASR_ZH00007| 5.51 | 4.97 |
+|SPEECHIO_ASR_ZH00008| 3.69 | 3.18 |
+|SPEECHIO_ASR_ZH00009| 3.02 | 2.78 |
+|SPEECHIO_ASR_ZH00010| 3.35 | 2.99 |
+|SPEECHIO_ASR_ZH00011| 1.54 | 1.25 |
+|SPEECHIO_ASR_ZH00012| 2.06 | 1.68 |
+|SPEECHIO_ASR_ZH00013| 2.57 | 2.25 |
+|SPEECHIO_ASR_ZH00014| 3.86 | 3.08 |
+|SPEECHIO_ASR_ZH00015| 3.34 | 2.67 |
diff --git a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
new file mode 100644
index 0000000..ed3b7e2
--- /dev/null
+++ b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
@@ -0,0 +1,17 @@
+
+
+##################text二进制数据#####################
+inputs = "hello 大 家 好 呀"
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipline = pipeline(
+ task=Tasks.language_model,
+ model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
+ output_dir="./tmp/"
+)
+
+rec_result = inference_pipline(text_in=inputs)
+print(rec_result)
+
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 16fa3e5..ca8f2bc 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -453,7 +453,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 3769b6c..6c5acfc 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -3,6 +3,9 @@
import logging
import sys
import time
+import copy
+import os
+import codecs
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -35,6 +38,8 @@
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -78,6 +83,7 @@
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
@@ -168,6 +174,34 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from file or str
+ if hotword_list_or_file is None:
+ self.hotword_list = None
+ elif os.path.exists(hotword_list_or_file):
+ self.hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ else:
+ logging.info("Attempting to parse hotwords as str...")
+ self.hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+
+
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
@@ -229,8 +263,14 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if not isinstance(self.asr_model, ContextualParaformer):
+ if self.hotword_list:
+ logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ else:
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
@@ -388,6 +428,11 @@
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
@@ -416,6 +461,7 @@
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
@@ -497,7 +543,7 @@
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
@@ -551,7 +597,12 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+ parser.add_argument(
+ "--hotword",
+ type=str_or_none,
+ default=None,
+ help="hotword file path or hotwords separated by space"
+ )
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -679,8 +730,10 @@
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
+ param_dict = {'hotword': args.hotword}
kwargs = vars(args)
kwargs.pop("config", None)
+ kwargs['param_dict'] = param_dict
inference(**kwargs)
diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index 7e2e414..7da48e2 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -436,7 +436,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 2832504..dbb2719 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -241,6 +241,11 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
finish_count = 0
file_count = 1
@@ -284,8 +289,10 @@
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
-
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
@@ -293,9 +300,11 @@
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
- text_postprocessed_punc = text_postprocessed
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+ text_postprocessed_punc = text_postprocessed
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1d09c79..c4bb61b 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -14,6 +14,7 @@
from typing import Any
from typing import List
import math
+import copy
import numpy as np
import torch
from typeguard import check_argument_types
@@ -38,8 +39,9 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
+from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
+from funasr.models.e2e_asr_paraformer import BiCifParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -234,6 +236,10 @@
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if isinstance(self.asr_model, BiCifParaformer):
+ _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ pre_token_length) # test no bias cif2
+
results = []
b, n, d = decoder_out.size()
for i in range(b):
@@ -276,9 +282,12 @@
else:
text = None
- time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
-
- results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+ if isinstance(self.asr_model, BiCifParaformer):
+ timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+ else:
+ time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
@@ -561,6 +570,11 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
finish_count = 0
file_count = 1
@@ -603,8 +617,11 @@
result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
-
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
@@ -612,9 +629,12 @@
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
- text_postprocessed_punc = text_postprocessed
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+ text_postprocessed_punc = text_postprocessed
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -492,7 +492,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -492,7 +492,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
index 27a8a71..198d578 100755
--- a/funasr/bin/lm_calc_perplexity.py
+++ b/funasr/bin/lm_calc_perplexity.py
@@ -56,7 +56,7 @@
set_all_random_seed(seed)
# 2. Build LM
- model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
# Wrape model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@
utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
# Write PPL of each utts for debugging or analysis
+ writer["utt2nll"][key] = str(-_nll)
writer["utt2ppl"][key] = str(utt_ppl)
writer["utt2ntokens"][key] = str(ntoken)
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
new file mode 100644
index 0000000..909cb02
--- /dev/null
+++ b/funasr/bin/lm_inference.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+def inference(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float],
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ output_dir=output_dir,
+ raw_inputs=raw_inputs,
+ batch_size=batch_size,
+ dtype=dtype,
+ ngpu=ngpu,
+ seed=seed,
+ num_workers=num_workers,
+ log_level=log_level,
+ key_file=key_file,
+ train_config=train_config,
+ model_file=model_file,
+ log_base = log_base,
+ allow_variable_data_keys = allow_variable_data_keys,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build Model
+ model, train_args = LMTask.build_model_from_file(
+ train_config, model_file, device)
+ wrapped_model = ForwardAdaptor(model, "nll")
+ wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ logging.info(f"Model:\n{model}")
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file
+ )
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ if output_dir_v2 is not None:
+ writer = DatadirWriter(output_dir_v2)
+ else:
+ writer = None
+
+        if raw_inputs is not None:
+ line = raw_inputs.strip()
+ key = "lm demo"
+ if line=="":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ batch = {}
+ batch['text'] = line
+            if preprocessor is not None:
+ batch = preprocessor(key, batch)
+
+ # Force data-precision
+ for name in batch:
+ value = batch[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype("float32")
+ elif value.dtype.kind == "i":
+                    value = value.astype("int64")
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ batch[name] = value
+
+ batch["text_lengths"] = torch.from_numpy(
+ np.array([len(batch["text"])], dtype='int32'))
+ batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## compute ppl
+ ppl_out_batch = ""
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for sent_ids, sent_nll in zip(batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ results.append(item)
+
+ return results
+
+ # 3. Build data-iterator
+ loader = LMTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=LMTask.build_collate_fn(train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4. Start for-loop
+ total_nll = 0.0
+ total_ntokens = 0
+ ppl_out_all = ""
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ ppl_out_batch = ""
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ # NOTE(kamo): data_parallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+            ## compute per-utterance PPL and write it out
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ utt2nll = round(-sent_nll_sum.item(), 5)
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ writer["utt2nll"][key] = str(utt2nll)
+ results.append(item)
+
+ ppl_out_all += ppl_out_batch
+
+ assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+ # nll: (B, L) -> (B,)
+ nll = nll.detach().cpu().numpy().sum(1)
+ # lengths: (B,)
+ lengths = lengths.detach().cpu().numpy()
+ total_nll += nll.sum()
+ total_ntokens += lengths.sum()
+
+ if log_base is None:
+ ppl = np.exp(total_nll / total_ntokens)
+ else:
+ ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+ avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+ total_nll=round(-total_nll.item(), 4),
+ total_ppl=round(ppl.item(), 4)
+ )
+ item = {'key': 'AVG PPL', 'value': avg_ppl}
+ ppl_out_all += avg_ppl
+ if writer is not None:
+ writer["ppl"]["AVG PPL : "] = avg_ppl
+ results.append(item)
+
+ return results
+
+ return _forward
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Calc perplexity",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=False)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ inference(**kwargs)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
new file mode 100644
index 0000000..492ebab
--- /dev/null
+++ b/funasr/bin/lm_inference_launch.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.types import float_or_none
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Calc perplexity",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--gpuid_list", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--njob", type=int, default=1, help="The number of jobs per gpu")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+ group.add_argument("--mode", type=str, default="lm")
+ return parser
+
+def inference_launch(mode, **kwargs):
+    if mode == "lm" or mode == "transformer":
+ from funasr.bin.lm_inference import inference_modelscope
+ return inference_modelscope(**kwargs)
+ else:
+        logging.error("Unknown decoding mode: {}".format(mode))
+ return None
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ kwargs.pop("gpuid_list", None)
+ kwargs.pop("njob", None)
+ results = inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index faa7a45..8641465 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,22 +1,46 @@
#!/usr/bin/env python3
+
+import os
+
from funasr.tasks.lm import LMTask
-def get_parser():
+# for LM Training
+def parse_args():
parser = LMTask.get_parser()
- return parser
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
-def main(cmd=None):
- """LM training.
-
- Example:
-
- % python lm_train.py asr --print_config --optim adadelta
- % python lm_train.py --config conf/train_asr.yaml
- """
- LMTask.main(cmd=cmd)
+def main(args=None, cmd=None):
+ # for LM Training
+ LMTask.main(args=args, cmd=cmd)
-if __name__ == "__main__":
- main()
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small" and args.ngpu != 0:
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
new file mode 100755
index 0000000..dc565d0
--- /dev/null
+++ b/funasr/bin/tokenize_text.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+import argparse
+from collections import Counter
+import logging
+from pathlib import Path
+import sys
+from typing import List
+from typing import Optional
+
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def field2slice(field: Optional[str]) -> slice:
+ """Convert field string to slice
+
+ Note that field string accepts 1-based integer.
+
+ Examples:
+ >>> field2slice("1-")
+ slice(0, None, None)
+ >>> field2slice("1-3")
+ slice(0, 3, None)
+ >>> field2slice("-3")
+ slice(None, 3, None)
+ """
+ field = field.strip()
+ try:
+ if "-" in field:
+ # e.g. "2-" or "2-5" or "-7"
+ s1, s2 = field.split("-", maxsplit=1)
+ if s1.strip() == "":
+ s1 = None
+ else:
+ s1 = int(s1)
+ if s1 == 0:
+ raise ValueError("1-based string")
+ if s2.strip() == "":
+ s2 = None
+ else:
+ s2 = int(s2)
+ else:
+ # e.g. "2"
+ s1 = int(field)
+ s2 = s1 + 1
+ if s1 == 0:
+ raise ValueError("must be 1 or more value")
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
+
+ if s1 is None:
+ slic = slice(None, s2)
+ else:
+ # -1 because of 1-based integer following "cut" command
+ # e.g "1-3" -> slice(0, 3)
+ slic = slice(s1 - 1, s2)
+ return slic
+
+
+def tokenize(
+ input: str,
+ output: str,
+ field: Optional[str],
+ delimiter: Optional[str],
+ token_type: str,
+ space_symbol: str,
+ non_linguistic_symbols: Optional[str],
+ bpemodel: Optional[str],
+ log_level: str,
+ write_vocabulary: bool,
+ vocabulary_size: int,
+ remove_non_linguistic_symbols: bool,
+ cutoff: int,
+ add_symbol: List[str],
+ cleaner: Optional[str],
+ g2p: Optional[str],
+):
+ assert check_argument_types()
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ if input == "-":
+ fin = sys.stdin
+ else:
+ fin = Path(input).open("r", encoding="utf-8")
+ if output == "-":
+ fout = sys.stdout
+ else:
+ p = Path(output)
+ p.parent.mkdir(parents=True, exist_ok=True)
+ fout = p.open("w", encoding="utf-8")
+
+ cleaner = TextCleaner(cleaner)
+ tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+ g2p_type=g2p,
+ )
+
+ counter = Counter()
+ if field is not None:
+ field = field2slice(field)
+
+ for line in fin:
+ line = line.rstrip()
+ if field is not None:
+ # e.g. field="2-"
+ # uttidA hello world!! -> hello world!!
+ tokens = line.split(delimiter)
+ tokens = tokens[field]
+ if delimiter is None:
+ line = " ".join(tokens)
+ else:
+ line = delimiter.join(tokens)
+
+ line = cleaner(line)
+ tokens = tokenizer.text2tokens(line)
+ if not write_vocabulary:
+ fout.write(" ".join(tokens) + "\n")
+ else:
+ for t in tokens:
+ counter[t] += 1
+
+ if not write_vocabulary:
+ return
+
+ ## FIXME
+    ## remove duplicate add_symbol entries from the counter
+ for symbol_and_id in add_symbol:
+ # e.g symbol="<blank>:0"
+ try:
+ symbol, idx = symbol_and_id.split(":")
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+ symbol = symbol.strip()
+ if symbol in counter:
+ del counter[symbol]
+
+ # ======= write_vocabulary mode from here =======
+ # Sort by the number of occurrences in descending order
+ # and filter lower frequency words than cutoff value
+ words_and_counts = list(
+ filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
+ )
+ # Restrict the vocabulary size
+ if vocabulary_size > 0:
+ if vocabulary_size < len(add_symbol):
+ raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
+ words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
+
+ # Parse the values of --add_symbol
+ for symbol_and_id in add_symbol:
+ # e.g symbol="<blank>:0"
+ try:
+ symbol, idx = symbol_and_id.split(":")
+ idx = int(idx)
+ except ValueError:
+ raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+ symbol = symbol.strip()
+
+ # e.g. idx=0 -> append as the first symbol
+ # e.g. idx=-1 -> append as the last symbol
+ if idx < 0:
+ idx = len(words_and_counts) + 1 + idx
+ words_and_counts.insert(idx, (symbol, None))
+
+ # Write words
+ for w, c in words_and_counts:
+ fout.write(w + "\n")
+
+ # Logging
+ total_count = sum(counter.values())
+ invocab_count = sum(c for w, c in words_and_counts if c is not None)
+ logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ description="Tokenize texts",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument(
+ "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
+ )
+ parser.add_argument(
+ "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
+ )
+ parser.add_argument(
+ "--field",
+ "-f",
+ help="The target columns of the input text as 1-based integer. e.g 2-",
+ )
+ parser.add_argument(
+ "--token_type",
+ "-t",
+ default="char",
+ choices=["char", "bpe", "word", "phn"],
+ help="Token type",
+ )
+ parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
+ parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
+ parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="non_linguistic_symbols file path",
+ )
+ parser.add_argument(
+ "--remove_non_linguistic_symbols",
+ type=str2bool,
+ default=False,
+ help="Remove non-language-symbols from tokens",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+
+ group = parser.add_argument_group("write_vocabulary mode related")
+ group.add_argument(
+ "--write_vocabulary",
+ type=str2bool,
+ default=False,
+ help="Write tokens list instead of tokenized text per line",
+ )
+ group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
+ group.add_argument(
+ "--cutoff",
+ default=0,
+ type=int,
+ help="cut-off frequency used for write-vocabulary mode",
+ )
+ group.add_argument(
+ "--add_symbol",
+ type=str,
+ default=[],
+ action="append",
+ help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ tokenize(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 10fbccb..79540c1 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -58,6 +58,15 @@
continue
return out_txt.strip().split()
+def seg_tokenize_wo_pattern(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
def framing(
x,
@@ -372,6 +381,70 @@
data = self._text_process(data)
return data
+## FIXME: LMPreprocessor duplicates most of CommonPreprocessor; consider refactoring the shared _text_process logic
+class LMPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train,
+ token_type,
+ token_list,
+ bpemodel,
+ text_cleaner,
+ g2p_type,
+ unk_symbol,
+ space_symbol,
+ non_linguistic_symbols,
+ delimiter,
+ rir_scp,
+ rir_apply_prob,
+ noise_scp,
+ noise_apply_prob,
+ noise_db_range,
+ speech_volume_normalize,
+ speech_name,
+ text_name,
+ split_with_space,
+ seg_dict_file,
+ )
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
class CommonPreprocessor_multi(AbsPreprocessor):
def __init__(
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index 4fc3b49..db11b67 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -46,10 +46,10 @@
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
- x = F.pad(text, [1, 0], "constant", self.eos)
+ x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
- t[i, l] = self.sos
+ t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
new file mode 100644
index 0000000..32f550a
--- /dev/null
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -0,0 +1,776 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.modules.streaming_utils import utils as myutils
+from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from typeguard import check_argument_types
+
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.modules.repeat import repeat
+from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+
+
class ContextualDecoderLayer(nn.Module):
    """Decoder layer that also exposes its intermediate attention outputs.

    Unlike a plain decoder layer, ``forward`` returns both the self-attention
    result and the source-attention result so a downstream bias (hotword)
    decoder can attend over them.

    Args:
        size: attention dimension of the layer.
        self_attn: self-attention module, called as
            ``self_attn(x, mask, cache=...)`` returning ``(out, cache)``.
        src_attn: encoder-decoder attention module, called as
            ``src_attn(x, memory, memory_mask)``.
        feed_forward: position-wise feed-forward module.
        dropout_rate: dropout probability on the residual branches.
        normalize_before: apply LayerNorm before each sub-block (pre-norm).
        concat_after: kept for interface parity with other decoder layers;
            creates the concat projections but they are not used in forward.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer object."""
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
        """Run FFN, self-attention and source-attention with residuals.

        Returns:
            (x, tgt_mask, x_self_attn, x_src_attn): fused output, the target
            mask (passed through), the post-residual self-attention output,
            and the raw source-attention output.
        """
        # Upstream stacked layers may hand over an (x, ...) tuple; keep the tensor.
        # (Fixed: isinstance against the builtin `tuple`, not typing.Tuple.)
        if isinstance(tgt, tuple):
            tgt, _ = tgt
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        if self.normalize_before:
            tgt = self.norm2(tgt)
        # The self-attention cache is only meaningful at inference time.
        if self.training:
            cache = None
        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + self.dropout(x)
        x_self_attn = x

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn
+
+
class ContexutalBiasDecoder(nn.Module):
    """Single cross-attention block that lets decoder states attend to the
    contextual (hotword) embeddings.

    Args:
        size: attention dimension.
        src_attn: cross-attention module called as
            ``src_attn(x, memory, memory_mask)``; if ``None`` the block is a
            pass-through.
        dropout_rate: dropout probability on the attention output.
        normalize_before: apply LayerNorm before the attention (pre-norm).
    """

    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct a ContexutalBiasDecoder object."""
        super().__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Attend ``tgt`` over ``memory``; everything else passes through."""
        x = tgt
        # Without a cross-attention module this block is the identity on x.
        if self.src_attn is None:
            return x, tgt_mask, memory, memory_mask, cache
        if self.normalize_before:
            x = self.norm3(x)
        x = self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
+
+
+class ContextualParaformerDecoder(ParaformerSANMDecoder):
+ """
+ author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+ """
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = 0,
+ ):
+ assert check_argument_types()
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+ if input_layer == 'none':
+ self.embed = None
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ # pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+ self.att_layer_num = att_layer_num
+ self.num_blocks = num_blocks
+ if sanm_shfit is None:
+ sanm_shfit = (kernel_size - 1) // 2
+ self.decoders = repeat(
+ att_layer_num - 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ self.dropout = nn.Dropout(dropout_rate)
+ self.bias_decoder = ContexutalBiasDecoder(
+ size=attention_dim,
+ src_attn=MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ dropout_rate=dropout_rate,
+ normalize_before=True,
+ )
+ self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
+ self.last_decoder = ContextualDecoderLayer(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ )
+ if num_blocks - att_layer_num <= 0:
+ self.decoders2 = None
+ else:
+ self.decoders2 = repeat(
+ num_blocks - att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+ ),
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.decoders3 = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            contextual_info: hotword embeddings; assumed
                (batch, num_hotwords, feat) — TODO confirm against caller.
            return_hidden: if True, skip the output projection and return
                the hidden states instead of vocabulary logits.
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        # Standard SANM layers (att_layer_num - 1 of them).
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        # Last attention layer exposes its self-/src-attention intermediates;
        # its fused output is deliberately discarded and recomputed below.
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        # All hotword rows are fully valid, so the mask covers the whole
        # contextual axis for every batch element.
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        # Cross-attend decoder self-attention states over hotword embeddings.
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.bias_output is not None:
            # Fuse encoder attention and hotword attention via a 1x1 conv,
            # then add back onto the self-attention branch as a residual.
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        # Optional self-attention-only tail layers.
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )

        # Final FFN-only layer.
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens
+
    def gen_tf2torch_map_dict(self):
        """Build the torch-state-dict-key -> TF-checkpoint-variable mapping.

        Each entry maps a torch parameter name (with the literal token
        ``layeridx`` standing in for the layer number) to a dict holding:
        ``name`` — the TF variable name (or a list of candidates),
        ``squeeze`` — axis to squeeze out of the TF tensor (or None),
        ``transpose`` — permutation to apply after squeezing (or None).
        The trailing comments show ``torch_shape,(tf_shape)``.
        Prefixes come from ``self.tf2torch_tensor_name_prefix_torch`` /
        ``..._tf`` — presumably set on the parent decoder class; confirm there.
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {

            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)

            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)

            # out layer
            # Two candidate TF sources (separate dense layer vs tied embeddings);
            # convert_tf2torch picks whichever exists in the checkpoint.
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)

            ## clas decoder
            # src att
            # NOTE: the bias decoder keys are full names (no "layeridx" token);
            # they map onto TF layer 15 directly.
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (1024,256),(1,256,1024)

        }
        return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+ map_dict = self.gen_tf2torch_map_dict()
+ var_dict_torch_update = dict()
+ decoder_layeridx_sets = set()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == self.tf2torch_tensor_name_prefix_torch:
+ if names[1] == "decoders":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[1] == "last_decoder":
+ layeridx = 15
+ name_q = name.replace("last_decoder", "decoders.layeridx")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+
+ elif names[1] == "decoders2":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ name_q = name_q.replace("decoders2", "decoders")
+ layeridx_bias = len(decoder_layeridx_sets)
+
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders3":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[1] == "bias_decoder":
+ name_q = name
+
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+
+ elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+ name_tf = map_dict[name]["name"]
+ if isinstance(name_tf, list):
+ idx_list = 0
+ if name_tf[idx_list] in var_dict_tf.keys():
+ pass
+ else:
+ idx_list = 1
+ data_tf = var_dict_tf[name_tf[idx_list]]
+ if map_dict[name]["squeeze"][idx_list] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+ if map_dict[name]["transpose"][idx_list] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+ name_tf[idx_list],
+ var_dict_tf[name_tf[
+ idx_list]].shape))
+
+ else:
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "after_norm":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed_concat_ffn":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 7596896..5786bc4 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -8,6 +8,8 @@
from typing import Union
import torch
+import random
+import numpy as np
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
@@ -24,7 +26,7 @@
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -824,7 +826,10 @@
class BiCifParaformer(Paraformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """
+ Paraformer model with an extra cif predictor
+ to conduct accurate timestamp prediction
+ """
def __init__(
self,
@@ -891,7 +896,7 @@
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
- def _calc_att_loss(
+ def _calc_pre2_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
@@ -903,47 +908,12 @@
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
- # 0. sampler
- decoder_out_1st = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
- loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2
+ return loss_pre2
def calc_predictor(self, encoder_out, encoder_out_lens):
@@ -956,10 +926,154 @@
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, None, encoder_out_mask, token_num=token_num,
- ignore_id=self.ignore_id)
- import pdb; pdb.set_trace()
+ ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ self.step_cur += 1
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ stats = dict()
+
+ loss_pre2 = self._calc_pre2_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ loss = loss_pre2
+
+ stats["loss_pre2"] = loss_pre2.detach().cpu()
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+class ContextualParaformer(Paraformer):
+ """
+ Paraformer model with contextual hotword
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ min_hw_length: int = 2,
+ max_hw_length: int = 4,
+ sample_rate: float = 0.6,
+ batch_rate: float = 0.5,
+ double_rate: float = -1.0,
+ target_buffer_length: int = -1,
+ inner_dim: int = 256,
+ bias_encoder_type: str = 'lstm',
+ label_bracket: bool = False,
+ ):
+ assert check_argument_types()
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
+ )
+
+ if bias_encoder_type == 'lstm':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
+ self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
+ else:
+ logging.error("Unsupport bias encoder type")
+
+ self.min_hw_length = min_hw_length
+ self.max_hw_length = max_hw_length
+ self.sample_rate = sample_rate
+ self.batch_rate = batch_rate
+ self.target_buffer_length = target_buffer_length
+ self.double_rate = double_rate
+
+ if self.target_buffer_length > 0:
+ self.hotword_buffer = None
+ self.length_record = []
+ self.current_buffer_length = 0
def forward(
self,
@@ -1038,17 +1152,17 @@
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss(
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+ loss = loss_att + loss_pre * self.predictor_weight
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1056,10 +1170,292 @@
stats["cer"] = cer_att
stats["wer"] = wer_att
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
- stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
\ No newline at end of file
+ return loss, stats, weight
+
    def _sample_hot_word(self, ys_pad, ys_pad_lens):
        """Randomly cut hotword phrases out of the target transcripts and
        encode them with the bias encoder, optionally maintaining a rolling
        cross-batch buffer of hotword embeddings.

        Args:
            ys_pad: padded target token ids, (B, Lmax).
            ys_pad_lens: per-utterance target lengths, (B,).

        Returns:
            Contextual embedding tensor repeated over the batch dimension
            (one copy of the hotword embeddings per utterance); exact shape
            depends on the bias encoder output — see the return statement.
        """
        # Entry 0 is a reserved single-token entry (id 0, presumably blank)
        # so the hotword list is never empty. TODO confirm token-0 semantics.
        hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
        # hw_lengths stores the LAST index of each hotword (length - 1); it is
        # later used to gather the bias-encoder output at the final token.
        hw_lengths = [0]
        for i, length in enumerate(ys_pad_lens):
            if length < 2:
                continue
            # With probability double_rate (disabled when double_rate < 0),
            # sample TWO non-overlapping hotwords from a long enough target.
            if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
                # sample double hotword
                _max_hw_length = min(self.max_hw_length, length // 2)
                # first hotword: starts in the first third of the utterance
                start1 = random.randint(0, length // 3)
                end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
                hw_tokens1 = ys_pad[i][start1:end1 + 1]
                hw_lengths.append(len(hw_tokens1) - 1)
                hw_list.append(hw_tokens1)
                # second hotword: starts strictly after the first one ends
                start2 = random.randint(end1 + 1, length - self.min_hw_length)
                end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
                                      min(length - 1, start2 + self.max_hw_length - 1))
                hw_tokens2 = ys_pad[i][start2:end2 + 1]
                hw_lengths.append(len(hw_tokens2) - 1)
                hw_list.append(hw_tokens2)
                continue
            # Otherwise sample a single hotword with probability sample_rate.
            if random.random() < self.sample_rate:
                if length == 2:
                    hw_tokens = ys_pad[i][:2]
                    hw_lengths.append(1)
                    hw_list.append(hw_tokens)
                else:
                    start = random.randint(0, length - self.min_hw_length)
                    end = random.randint(min(length - 1, start + self.min_hw_length - 1),
                                         min(length - 1, start + self.max_hw_length - 1)) + 1
                    hw_tokens = ys_pad[i][start:end]
                    hw_lengths.append(len(hw_tokens) - 1)
                    hw_list.append(hw_tokens)
        # padding: right-pad hotwords to a common length, embed, then run the
        # bias encoder (LSTM) over each hotword sequence.
        hw_list_pad = pad_list(hw_list, 0)
        hw_embed = self.decoder.embed(hw_list_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, len(hw_list)).tolist()
        # Gather each hotword's encoder output at its final token position.
        # update self.hotword_buffer, throw a part if oversize
        selected = hw_embed[_ind, hw_lengths]
        if self.target_buffer_length > 0:
            _b = selected.shape[0]
            if self.hotword_buffer is None:
                # First batch: start the buffer with this batch's hotwords.
                self.hotword_buffer = selected
                self.length_record.append(selected.shape[0])
                self.current_buffer_length = _b
            elif self.current_buffer_length + _b < self.target_buffer_length:
                # Buffer has room: append (detached, so old entries carry no
                # gradient) and train against the whole buffer.
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                self.current_buffer_length += _b
                selected = self.hotword_buffer
            else:
                # Buffer full: append, then keep only a random-sized tail of
                # roughly target_buffer_length/2 .. target_buffer_length (+10).
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
                self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
                selected = self.hotword_buffer
                self.current_buffer_length = selected.shape[0]
        # Repeat the hotword embeddings once per utterance in the batch.
        # NOTE(review): squeeze(0) is a no-op unless exactly one hotword was
        # selected — confirm intended shape with the decoder's contextual input.
        return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # sample hot word
+ contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds, contextual_info)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+ if hw_list is None:
+ # default hotword list
+ hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list
+ hw_list_pad = pad_list(hw_list, 0)
+ hw_embed = self.bias_embed(hw_list_pad)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+ else:
+ hw_lengths = [len(i) for i in hw_list]
+ hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+ enforce_sorted=False)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True)
+ contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def gen_clas_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = "bias_encoder"
+ tensor_name_prefix_tf = "seq2seq/clas_charrnn"
+
+ tensor_name_prefix_torch_emb = "bias_embed"
+ tensor_name_prefix_tf_emb = "seq2seq"
+
+ map_dict_local = {
+ # in lstm
+ "{}.weight_ih_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (1, 0),
+ "slice": (0, 512),
+ "unit_k": 512,
+ }, # (1024, 2048),(2048,512)
+ "{}.weight_hh_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (1, 0),
+ "slice": (512, 1024),
+ "unit_k": 512,
+ }, # (1024, 2048),(2048,512)
+ "{}.bias_ih_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ "scale": 0.5,
+ "unit_b": 512,
+ }, # (2048,),(2048,)
+ "{}.bias_hh_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ "scale": 0.5,
+ "unit_b": 512,
+ }, # (2048,),(2048,)
+
+ # in embed
+ "{}.weight".format(tensor_name_prefix_torch_emb):
+ {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb),
+ "squeeze": None,
+ "transpose": None,
+ }, # (4235,256),(4235,256)
+ }
+ return map_dict_local
+
    def clas_convert_tf2torch(self,
                              var_dict_tf,
                              var_dict_torch):
        """Convert TF CLAS (contextual biasing) checkpoint tensors into torch
        parameter layout for ``bias_encoder`` and ``bias_embed``.

        Args:
            var_dict_tf: dict mapping TF variable name -> numpy array.
            var_dict_torch: dict mapping torch parameter name -> tensor; used
                to enumerate target names and to shape-check conversions.

        Returns:
            dict: torch parameter name -> converted ``torch.float32`` tensor,
            covering only the bias-encoder/embedding parameters.
        """
        map_dict = self.gen_clas_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == "bias_encoder":
                name_q = name
                if name_q in map_dict.keys():
                    name_v = map_dict[name_q]["name"]
                    name_tf = name_v
                    data_tf = var_dict_tf[name_tf]
                    # Reorder the four LSTM gate blocks of a 2-D kernel.
                    # NOTE(review): local names label the TF blocks (i, f, o, g)
                    # and the concat emits (i, o, f, g); verify this matches the
                    # TF cell's actual gate order and torch's (i, f, g, o).
                    if map_dict[name_q].get("unit_k") is not None:
                        dim = map_dict[name_q]["unit_k"]
                        i = data_tf[:, 0:dim].copy()
                        f = data_tf[:, dim:2 * dim].copy()
                        o = data_tf[:, 2 * dim:3 * dim].copy()
                        g = data_tf[:, 3 * dim:4 * dim].copy()
                        data_tf = np.concatenate([i, o, f, g], axis=1)
                    # Same gate-block reordering for a 1-D bias vector.
                    if map_dict[name_q].get("unit_b") is not None:
                        dim = map_dict[name_q]["unit_b"]
                        i = data_tf[0:dim].copy()
                        f = data_tf[dim:2 * dim].copy()
                        o = data_tf[2 * dim:3 * dim].copy()
                        g = data_tf[3 * dim:4 * dim].copy()
                        data_tf = np.concatenate([i, o, f, g], axis=0)
                    # Optional post-processing steps driven by the map entry:
                    # squeeze an axis, slice rows, scale values, transpose axes.
                    if map_dict[name_q]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                    if map_dict[name_q].get("slice") is not None:
                        data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
                    if map_dict[name_q].get("scale") is not None:
                        data_tf = data_tf * map_dict[name_q]["scale"]
                    if map_dict[name_q]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    # Converted tensor must match the torch parameter's shape.
                    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                    var_dict_torch[
                                                                                                        name].size(),
                                                                                                    data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                      var_dict_tf[name_tf].shape))
            elif names[0] == "bias_embed":
                # Embedding table: no gate reordering, only squeeze/transpose.
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                var_dict_torch[
                                                                                                    name].size(),
                                                                                                data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))

        return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c34759d..5615373 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -544,9 +544,8 @@
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
- def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
- target_label_length=None, token_num=None):
+
+ def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 7899400..02311fd 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -43,6 +43,7 @@
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
@@ -1272,6 +1273,52 @@
if args.dry_run:
pass
+ elif args.collect_stats:
+ # Perform on collect_stats mode. This mode has two roles
+ # - Derive the length and dimension of all input data
+ # - Accumulate feats, square values, and the length for whitening
+
+ if args.valid_batch_size is None:
+ args.valid_batch_size = args.batch_size
+
+ if len(args.train_shape_file) != 0:
+ train_key_file = args.train_shape_file[0]
+ else:
+ train_key_file = None
+ if len(args.valid_shape_file) != 0:
+ valid_key_file = args.valid_shape_file[0]
+ else:
+ valid_key_file = None
+
+ collect_stats(
+ model=model,
+ train_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+ key_file=train_key_file,
+ batch_size=args.batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ valid_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+ key_file=valid_key_file,
+ batch_size=args.valid_batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ output_dir=output_dir,
+ ngpu=args.ngpu,
+ log_interval=args.log_interval,
+ write_collected_feats=args.write_collected_feats,
+ )
else:
logging.info("Training args: {}".format(args))
# 6. Loads pre-trained model
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 1b7f152..e62a748 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -37,8 +37,9 @@
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
),
type_check=AbsESPnetModel,
default="asr",
@@ -177,6 +179,7 @@
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -1098,5 +1101,8 @@
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 46b9fe0..608c1d3 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -58,7 +58,7 @@
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
- required += ["token_list"]
+ # required += ["token_list"]
group.add_argument(
"--token_list",
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index 4da0d59..575fb90 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -232,5 +232,9 @@
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != ' ':
+ real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
- return sentence
+ return sentence, real_word_lists
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 3afaa40..33d1255 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -86,14 +86,51 @@
else:
return time_stamp_list
-
-def time_stamp_lfr6_advance(tst: List, text: str):
- # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst
- if text.endswith('</s>'):
- text = text[:-4]
+def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
+ START_END_THRESHOLD = 5
+ TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
+ if len(us_alphas.shape) == 3:
+ alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
else:
- text = text[:-1]
- logging.warning("found text does not end with </s>")
- assert int(ds_alphas.sum() + 1e-4) - 1 == len(text)
-
+ alphas, cif_peak = us_alphas, us_cif_peak
+ num_frames = cif_peak.shape[0]
+ if char_list[-1] == '</s>':
+ char_list = char_list[:-1]
+ # char_list = [i for i in text]
+ timestamp_list = []
+ # for bicif model trained with large data, cif2 actually fires when a character starts
+ # so treat the frames between two peaks as the duration of the former token
+ fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
+ num_peak = len(fire_place)
+ assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+ # begin silence
+ if fire_place[0] > START_END_THRESHOLD:
+ char_list.insert(0, '<sil>')
+ timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
+ # tokens timestamp
+ for i in range(len(fire_place)-1):
+ # the peak is always a little ahead of the start time
+ # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ # cut the duration to token and sil of the 0-weight frames last long
+ # tail token and end silence
+ if num_frames - fire_place[-1] > START_END_THRESHOLD:
+ _end = (num_frames + fire_place[-1]) / 2
+ timestamp_list[-1][1] = _end*TIME_RATE
+ timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
+ char_list.append("<sil>")
+ else:
+ timestamp_list[-1][1] = num_frames*TIME_RATE
+ if begin_time: # add offset time in model with vad
+ for i in range(len(timestamp_list)):
+ timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
+ timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
+ res_txt = ""
+ for char, timestamp in zip(char_list, timestamp_list):
+ res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
+ res = []
+ for char, timestamp in zip(char_list, timestamp_list):
+ if char != '<sil>':
+ res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
+ return res
+
--
Gitblit v1.9.1