From f3cd90dcf21e2d4ca451abbfdc841ac6abfc68ee Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 14 二月 2023 14:59:03 +0800
Subject: [PATCH] Merge pull request #105 from yufan-aslp/main
---
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py | 103 ++
funasr/bin/asr_inference_mfcca.py | 771 +++++++++++++++++++
funasr/models/e2e_asr_mfcca.py | 322 ++++++++
funasr/models/encoder/mfcca_encoder.py | 450 +++++++++++
funasr/models/encoder/encoder_layer_mfcca.py | 270 ++++++
funasr/tasks/asr.py | 138 +++
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md | 40 +
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py | 35
funasr/models/frontend/default.py | 125 +++
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md | 53 +
funasr/bin/build_trainer.py | 2
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py | 67 +
funasr/bin/asr_inference_launch.py | 6
13 files changed, 2,382 insertions(+), 0 deletions(-)
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md
new file mode 100644
index 0000000..9097e7a
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained MFCCA Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - <strong>max_epoch:</strong> # number of training epoch
+ - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+ python finetune.py
+```
+
+### Inference
+
+Alternatively, you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+ - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
+ - <strong>output_dir:</strong> # result dir
+ - <strong>ngpu:</strong> # the number of GPUs for decoding
+ - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which includes recognition results with or without separating character (src) of each sample and the CER metric of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to finetune with:
+```python
+ python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which includes recognition results with or without separating character (src) of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
new file mode 100644
index 0000000..8f58259
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
@@ -0,0 +1,40 @@
+# MFCCA
+- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
+- Model size: 45M
+
+# Environments
+- date: `Tue Feb 13 20:13:22 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.0`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## result (paper)
+beam=20, CER tool: https://github.com/yufan-aslp/AliMeeting
+
+| model | Para (M) | Data (hrs) | Eval (CER%) | Test (CER%) |
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45 | 917 | 16.1 | 17.5 |
+
+## result (modelscope)
+
+beam=10
+
+with separating character (src)
+
+| model | Para (M) | Data (hrs) | Eval_sp (CER%) | Test_sp (CER%) |
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45 | 917 | 17.1 | 18.6 |
+
+without separating character (src)
+
+| model | Para (M) | Data (hrs) | Eval_nosp (CER%) | Test_nosp (CER%) |
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45 | 917 | 16.4 | 18.0 |
+
+## Deviation
+
+Considering the differences of the CER calculation tool and decoding beam size, the results of CER are biased (<0.5%).
\ No newline at end of file
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
new file mode 100755
index 0000000..281292f
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
@@ -0,0 +1,35 @@
+import os
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+def modelscope_finetune(params):
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
+ # dataset split ["train", "validation"]
+ ds_dict = MsDataset.load(params.data_path)
+ kwargs = dict(
+ model=params.model,
+ model_revision=params.model_revision,
+ data_dir=ds_dict,
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
+ trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+
+ params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
+ params.output_dir = "./checkpoint" # m妯″瀷淇濆瓨璺緞
+ params.data_path = "./example_data/" # 鏁版嵁璺緞
+ params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
+ params.batch_bins = 1000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
+ params.max_epoch = 10 # 鏈�澶ц缁冭疆鏁�
+ params.lr = 0.0001 # 璁剧疆瀛︿範鐜�
+ params.model_revision = 'v2.0.0'
+ modelscope_finetune(params)
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
new file mode 100755
index 0000000..3054394
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
@@ -0,0 +1,103 @@
+import os
+import shutil
+from multiprocessing import Pool
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+import pdb;
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+ output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+ gpu_id = (int(idx) - 1) // njob
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+ inference_pipline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
+ model_revision='v2.0.0',
+ output_dir=output_dir_job,
+ batch_size=1,
+ )
+ audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+ inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+ # prepare for multi-GPU decoding
+ ngpu = params["ngpu"]
+ njob = params["njob"]
+ output_dir = params["output_dir"]
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ os.mkdir(output_dir)
+ split_dir = os.path.join(output_dir, "split")
+ os.mkdir(split_dir)
+ nj = ngpu * njob
+ wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(modelscope_infer_core,
+ args=(output_dir, split_dir, njob, str(i + 1)))
+ p.close()
+ p.join()
+
+ # combine decoding results
+ best_recog_path = os.path.join(output_dir, "1best_recog")
+ os.mkdir(best_recog_path)
+ files = ["text", "token", "score"]
+ for file in files:
+ with open(os.path.join(best_recog_path, file), "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+ with open(job_file) as f_job:
+ lines = f_job.readlines()
+ f.writelines(lines)
+
+ # If text exists, compute CER
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(best_recog_path, "token")
+ text_proc_file2 = os.path.join(best_recog_path, "token_nosep")
+ with open(text_proc_file, 'r') as hyp_reader:
+ with open(text_proc_file2, 'w') as hyp_writer:
+ for line in hyp_reader:
+ new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
+ hyp_writer.write(new_context+'\n')
+ text_in2 = os.path.join(best_recog_path, "ref_text_nosep")
+ with open(text_in, 'r') as ref_reader:
+ with open(text_in2, 'w') as ref_writer:
+ for line in ref_reader:
+ new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
+ ref_writer.write(new_context+'\n')
+
+
+ compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.sp.cer"))
+ compute_wer(text_in2, text_proc_file2, os.path.join(best_recog_path, "text.nosp.cer"))
+
+
+if __name__ == "__main__":
+ params = {}
+ params["data_dir"] = "./example_data/validation"
+ params["output_dir"] = "./output_dir"
+ params["ngpu"] = 1
+ params["njob"] = 1
+ modelscope_infer(params)
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
new file mode 100755
index 0000000..00faad0
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
@@ -0,0 +1,67 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+ # prepare for decoding
+ pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+ for file_name in params["required_files"]:
+ if file_name == "configuration.json":
+ with open(os.path.join(pretrained_model_path, file_name)) as f:
+ config_dict = json.load(f)
+ config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+ with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+ json.dump(config_dict, f, indent=4, separators=(',', ': '))
+ else:
+ shutil.copy(os.path.join(pretrained_model_path, file_name),
+ os.path.join(params["output_dir"], file_name))
+ decoding_path = os.path.join(params["output_dir"], "decode_results")
+ if os.path.exists(decoding_path):
+ shutil.rmtree(decoding_path)
+ os.mkdir(decoding_path)
+
+ # decoding
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=params["output_dir"],
+ output_dir=decoding_path,
+ batch_size=1
+ )
+ audio_in = os.path.join(params["data_dir"], "wav.scp")
+ inference_pipeline(audio_in=audio_in)
+
+ # computer CER if GT text is set
+ text_in = os.path.join(params["data_dir"], "text")
+ if text_in is not None:
+ text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
+ with open(text_proc_file, 'r') as hyp_reader:
+ with open(text_proc_file2, 'w') as hyp_writer:
+ for line in hyp_reader:
+ new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
+ hyp_writer.write(new_context+'\n')
+ text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
+ with open(text_in, 'r') as ref_reader:
+ with open(text_in2, 'w') as ref_writer:
+ for line in ref_reader:
+ new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
+ ref_writer.write(new_context+'\n')
+
+
+ compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
+ compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))
+
+if __name__ == '__main__':
+ params = {}
+ params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
+ params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
+ params["output_dir"] = "./checkpoint"
+ params["data_dir"] = "./example_data/validation"
+ params["decoding_model_name"] = "valid.acc.ave.pth"
+ modelscope_infer_after_finetune(params)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index d2798b1..d8d5679 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -228,6 +228,9 @@
elif mode == "vad":
from funasr.bin.vad_inference import inference_modelscope
return inference_modelscope(**kwargs)
+ elif mode == "mfcca":
+ from funasr.bin.asr_inference_mfcca import inference_modelscope
+ return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
@@ -253,6 +256,9 @@
elif mode == "vad":
from funasr.bin.vad_inference import inference
return inference(**kwargs)
+ elif mode == "mfcca":
+ from funasr.bin.asr_inference_mfcca import inference_modelscope
+ return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py
new file mode 100644
index 0000000..0da66c5
--- /dev/null
+++ b/funasr/bin/asr_inference_mfcca.py
@@ -0,0 +1,771 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
+from funasr.modules.beam_search.beam_search import BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+import pdb
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+ 'audio_fs': 16000,
+ 'model_fs': 16000
+}
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ lm.to(device)
+ scorers["lm"] = lm.lm
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+ #beam_search.__class__ = BatchBeamSearch
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ) -> List[
+ Tuple[
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+
+ #speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ speech = speech.to(getattr(torch, self.dtype))
+ # lenghts: (1,)
+ lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ batch = {"speech": speech, "speech_lengths": lengths}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, _ = self.asr_model.encode(**batch)
+
+ assert len(enc) == 1, len(enc)
+
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+
+# def inference(
+# maxlenratio: float,
+# minlenratio: float,
+# batch_size: int,
+# beam_size: int,
+# ngpu: int,
+# ctc_weight: float,
+# lm_weight: float,
+# penalty: float,
+# log_level: Union[int, str],
+# data_path_and_name_and_type,
+# asr_train_config: Optional[str],
+# asr_model_file: Optional[str],
+# cmvn_file: Optional[str] = None,
+# lm_train_config: Optional[str] = None,
+# lm_file: Optional[str] = None,
+# token_type: Optional[str] = None,
+# key_file: Optional[str] = None,
+# word_lm_train_config: Optional[str] = None,
+# bpemodel: Optional[str] = None,
+# allow_variable_data_keys: bool = False,
+# streaming: bool = False,
+# output_dir: Optional[str] = None,
+# dtype: str = "float32",
+# seed: int = 0,
+# ngram_weight: float = 0.9,
+# nbest: int = 1,
+# num_workers: int = 1,
+# **kwargs,
+# ):
+# assert check_argument_types()
+# if batch_size > 1:
+# raise NotImplementedError("batch decoding is not implemented")
+# if word_lm_train_config is not None:
+# raise NotImplementedError("Word LM is not implemented")
+# if ngpu > 1:
+# raise NotImplementedError("only single GPU decoding is supported")
+#
+# logging.basicConfig(
+# level=log_level,
+# format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+# )
+#
+# if ngpu >= 1 and torch.cuda.is_available():
+# device = "cuda"
+# else:
+# device = "cpu"
+#
+# # 1. Set random-seed
+# set_all_random_seed(seed)
+#
+# # 2. Build speech2text
+# speech2text_kwargs = dict(
+# asr_train_config=asr_train_config,
+# asr_model_file=asr_model_file,
+# cmvn_file=cmvn_file,
+# lm_train_config=lm_train_config,
+# lm_file=lm_file,
+# token_type=token_type,
+# bpemodel=bpemodel,
+# device=device,
+# maxlenratio=maxlenratio,
+# minlenratio=minlenratio,
+# dtype=dtype,
+# beam_size=beam_size,
+# ctc_weight=ctc_weight,
+# lm_weight=lm_weight,
+# ngram_weight=ngram_weight,
+# penalty=penalty,
+# nbest=nbest,
+# streaming=streaming,
+# )
+# logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+# speech2text = Speech2Text(**speech2text_kwargs)
+#
+# # 3. Build data-iterator
+# loader = ASRTask.build_streaming_iterator(
+# data_path_and_name_and_type,
+# dtype=dtype,
+# batch_size=batch_size,
+# key_file=key_file,
+# num_workers=num_workers,
+# preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+# collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+# allow_variable_data_keys=allow_variable_data_keys,
+# inference=True,
+# )
+#
+# finish_count = 0
+# file_count = 1
+# # 7 .Start for-loop
+# # FIXME(kamo): The output format should be discussed about
+# asr_result_list = []
+# if output_dir is not None:
+# writer = DatadirWriter(output_dir)
+# else:
+# writer = None
+#
+# for keys, batch in loader:
+# assert isinstance(batch, dict), type(batch)
+# assert all(isinstance(s, str) for s in keys), keys
+# _bs = len(next(iter(batch.values())))
+# assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+# #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+#
+# # N-best list of (text, token, token_int, hyp_object)
+# try:
+# results = speech2text(**batch)
+# except TooShortUttError as e:
+# logging.warning(f"Utterance {keys} {e}")
+# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+# results = [[" ", ["<space>"], [2], hyp]] * nbest
+#
+# # Only supporting batch_size==1
+# key = keys[0]
+# for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+# # Create a directory: outdir/{n}best_recog
+# if writer is not None:
+# ibest_writer = writer[f"{n}best_recog"]
+#
+# # Write the result to each file
+# ibest_writer["token"][key] = " ".join(token)
+# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+# ibest_writer["score"][key] = str(hyp.score)
+#
+# if text is not None:
+# text_postprocessed = postprocess_utils.sentence_postprocess(token)
+# item = {'key': key, 'value': text_postprocessed}
+# asr_result_list.append(item)
+# finish_count += 1
+# asr_utils.print_progress(finish_count / file_count)
+# if writer is not None:
+# ibest_writer["text"][key] = text
+# return asr_result_list
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
+
+def set_parameters(language: str = None,
+                   sample_rate: Union[int, Dict[Any, int]] = None):
+    """Update module-level ASR settings.
+
+    Args:
+        language: Target ASR language code; only applied when not None.
+        sample_rate: Sample rate, or a mapping of keys to sample rates;
+            only applied when not None.
+    """
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+    if sample_rate is not None:
+        global global_sample_rate
+        global_sample_rate = sample_rate
+
+
+def get_parser():
+    """Build the command-line argument parser for ASR decoding.
+
+    Returns:
+        config_argparse.ArgumentParser: parser with logging, input-data,
+        model-configuration, beam-search and text-converter option groups.
+    """
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    # raw_inputs passes in-memory waveform entries instead of scp files.
+    group.add_argument("--raw_inputs", type=list, default=None)
+    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+        "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+        "If maxlenratio=0.0 (default), it uses a end-detect "
+        "function "
+        "to automatically find maximum hypothesis lengths."
+        "If maxlenratio<0.0, its absolute value is interpreted"
+        "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+        "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+        "If not given, refers from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    """CLI entry point: parse arguments and run inference.
+
+    Args:
+        cmd: Optional argument list; when None, argparse reads sys.argv[1:].
+    """
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    # "config" is consumed by config_argparse itself and is not a
+    # parameter of inference(), so drop it before forwarding.
+    kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index bb1d7a7..8dee758 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -27,6 +27,8 @@
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
elif mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+ elif mode == "mfcca":
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
new file mode 100644
index 0000000..0336133
--- /dev/null
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -0,0 +1,322 @@
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+import pdb
+import random
+import math
+class MFCCA(AbsESPnetModel):
+    """CTC-attention hybrid Encoder-Decoder model.
+
+    Multi-channel variant: the input may carry a channel axis (Batch,
+    Length, Channel); a random channel-dropping augmentation is applied
+    during training when ``mask_ratio`` > 0 (see :meth:`forward`).
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        preencoder: Optional[AbsPreEncoder],
+        encoder: AbsEncoder,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        rnnt_decoder: None,
+        ctc_weight: float = 0.5,
+        ignore_id: int = -1,
+        lsm_weight: float = 0.0,
+        mask_ratio: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+    ):
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert rnnt_decoder is None, "Not implemented"
+
+        super().__init__()
+        # note that eos is the same as sos (equivalent ID)
+        self.sos = vocab_size - 1
+        self.eos = vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.token_list = token_list.copy()
+
+        # Probability of triggering the random channel-dropping
+        # augmentation in forward(); 0.0 disables it.
+        self.mask_ratio = mask_ratio
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        # we set self.decoder = None in the CTC mode since
+        # self.decoder parameters were never used and PyTorch complained
+        # and threw an Exception in the multi-GPU experiment.
+        # thanks Jeff Farris for pointing out the issue.
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+        self.rnnt_decoder = rnnt_decoder
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if report_cer or report_wer:
+            self.error_calculator = ErrorCalculator(
+                token_list, sym_space, sym_blank, report_cer, report_wer
+            )
+        else:
+            self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        # Channel-dropping augmentation: for 8-channel input
+        # (Batch, Length, 8), with probability mask_ratio keep only a
+        # random subset of the channels.
+        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+            rate_num = random.random()
+            if(rate_num<=self.mask_ratio):
+                # Number of channels to keep, uniform in 1..8.
+                retain_channel = math.ceil(random.random() *8)
+                if(retain_channel>1):
+                    # Keep a sorted random subset of the 8 channels.
+                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+                else:
+                    # NOTE(review): a scalar index drops the channel axis,
+                    # so speech becomes 2-D (Batch, Length) here — confirm
+                    # the frontend accepts single-channel input this way.
+                    speech = speech[:,:,torch.randperm(8)[0]]
+        batch_size = speech.shape[0]
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # 2a. Attention-decoder branch
+        if self.ctc_weight == 1.0:
+            loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        else:
+            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 2b. CTC branch
+        if self.ctc_weight == 0.0:
+            loss_ctc, cer_ctc = None, None
+        else:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 2c. RNN-T branch (always raises; rnnt_decoder must be None)
+        if self.rnnt_decoder is not None:
+            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
+
+        # Interpolate CTC and attention losses by ctc_weight.
+        if self.ctc_weight == 0.0:
+            loss = loss_att
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_att=loss_att.detach() if loss_att is not None else None,
+            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
+            acc=acc_att,
+            cer=cer_att,
+            wer=wer_att,
+            cer_ctc=cer_ctc,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        """Extract features only (for statistics collection).
+
+        The channel_size returned by _extract_feats is intentionally
+        discarded here.
+        """
+        feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+        # The MFCCA encoder additionally takes the channel count.
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        # 4-D output keeps a channel axis at dim 1, so time is dim 2;
+        # otherwise time is dim 1.
+        if(encoder_out.dim()==4):
+            assert encoder_out.size(2) <= encoder_out_lens.max(), (
+                encoder_out.size(),
+                encoder_out_lens.max(),
+            )
+        else:
+            assert encoder_out.size(1) <= encoder_out_lens.max(), (
+                encoder_out.size(),
+                encoder_out_lens.max(),
+            )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        # Returns (feats, feats_lengths, channel_size). channel_size is 1
+        # when no frontend is configured, otherwise whatever the frontend
+        # reports (assumed int — TODO confirm against the frontend).
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+        if self.frontend is not None:
+            # Frontend
+            # e.g. STFT and Feature extract
+            # data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+            channel_size = 1
+        return feats, feats_lengths, channel_size
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Prepend sos / append eos to the targets; length grows by one.
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        if(encoder_out.dim()==4):
+            # 4-D encoder output: average over dim 1 (the channel axis,
+            # per the shape checks in encode()) before CTC.
+            encoder_out = encoder_out.mean(1)
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def _calc_rnnt_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # RNN-T training is not supported by this model.
+        raise NotImplementedError
diff --git a/funasr/models/encoder/encoder_layer_mfcca.py b/funasr/models/encoder/encoder_layer_mfcca.py
new file mode 100644
index 0000000..e0bd006
--- /dev/null
+++ b/funasr/models/encoder/encoder_layer_mfcca.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
+# Northwestern Polytechnical University (Pengcheng Guo)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Encoder self-attention layer definition."""
+
+import torch
+
+from torch import nn
+
+from funasr.modules.layer_norm import LayerNorm
+from torch.autograd import Variable
+
+
+
+class Encoder_Conformer_Layer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvlutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+        cca_pos (int): When < 2, the positional embedding (if any) is fed to
+            the self-attention module; when >= 2 it is not (see forward()).
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        feed_forward_macaron,
+        conv_module,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        cca_pos=0,
+    ):
+        """Construct an Encoder_Conformer_Layer object."""
+        super(Encoder_Conformer_Layer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = LayerNorm(size)  # for the FNN module
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = LayerNorm(size)
+            # Macaron style halves each feed-forward residual contribution.
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = LayerNorm(size)  # for the CNN module
+            self.norm_final = LayerNorm(size)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.cca_pos = cca_pos
+
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+
+    def forward(self, x_input, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            # Incremental decoding: only the newest frame queries; keys and
+            # values still cover the full sequence.
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        # cca_pos >= 2 layers never pass the positional embedding to
+        # self-attention, even when one is available.
+        if self.cca_pos<2:
+            if pos_emb is not None:
+                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+            else:
+                x_att = self.self_attn(x_q, x, x, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, x_att), dim=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x = residual + self.dropout(self.conv_module(x))
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
+
+
+
+
+class EncoderLayer(nn.Module):
+    """MFCCA encoder layer: cross-channel attention + conformer sub-layer.
+
+    Args:
+        size (int): Input dimension.
+        self_attn_cros_channel (torch.nn.Module): Attention module used for
+            the cross-channel attention over the channel axis.
+        self_attn_conformer (torch.nn.Module): Self-attention module for the
+            inner `Encoder_Conformer_Layer`.
+        feed_forward_csa (torch.nn.Module): Feed-forward module for the
+            inner conformer sub-layer.
+        feed_forward_macaron_csa (torch.nn.Module): Additional macaron
+            feed-forward module for the inner conformer sub-layer.
+        conv_module_csa (torch.nn.Module): Convolution module for the inner
+            conformer sub-layer.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn_cros_channel,
+        self_attn_conformer,
+        feed_forward_csa,
+        feed_forward_macaron_csa,
+        conv_module_csa,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+
+        self.encoder_cros_channel_atten = self_attn_cros_channel
+        self.encoder_csa = Encoder_Conformer_Layer(
+            size,
+            self_attn_conformer,
+            feed_forward_csa,
+            feed_forward_macaron_csa,
+            conv_module_csa,
+            dropout_rate,
+            normalize_before,
+            concat_after,
+            cca_pos=0)
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x_input, mask, channel_size, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            channel_size (int): Number of input channels stacked into the
+                batch dimension of x_input.
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+            int: channel_size, passed through unchanged.
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+        residual = x
+        x = self.norm_mha(x)
+        t_leng = x.size(1)
+        d_dim = x.size(2)
+        # Unstack channels from the batch axis: (B*C, T, D) -> (B, T, C, D).
+        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2)  # x_new B*T * C * D
+        # Keys/values gather the channels of a 5-frame temporal context
+        # window (t-2 .. t+2), zero-padded at the sequence boundaries:
+        # x_k_v has shape (B, T, 5, C, D).
+        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
+        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
+        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
+        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
+        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
+        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
+        x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
+        x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
+        x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
+        # Cross-channel attention: each (frame, channel) queries the
+        # 5*channel_size context entries.
+        x_new = x_new.reshape(-1,channel_size,d_dim)
+        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
+        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
+        # Restack channels into the batch axis: back to (B*C, T, D).
+        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
+        x = residual + self.dropout(x_att)
+        if pos_emb is not None:
+            x_input = (x, pos_emb)
+        else:
+            x_input = x
+        # Then run the standard conformer sub-layer.
+        x_input, mask = self.encoder_csa(x_input, mask)
+
+        return x_input, mask , channel_size
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
new file mode 100644
index 0000000..83d0b0e
--- /dev/null
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -0,0 +1,450 @@
+from typing import Optional
+from typing import Tuple
+
+import logging
+import torch
+from torch import nn
+
+from typeguard import check_argument_types
+
+from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
+from funasr.modules.nets_utils import get_activation
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import (
+ MultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttention, # noqa: H301
+ LegacyRelPositionMultiHeadedAttention, # noqa: H301
+)
+from funasr.modules.embedding import (
+ PositionalEncoding, # noqa: H301
+ ScaledPositionalEncoding, # noqa: H301
+ RelPositionalEncoding, # noqa: H301
+ LegacyRelPositionalEncoding, # noqa: H301
+)
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.multi_layer_conv import Conv1dLinear
+from funasr.modules.multi_layer_conv import MultiLayeredConv1d
+from funasr.modules.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import Conv2dSubsampling
+from funasr.modules.subsampling import Conv2dSubsampling2
+from funasr.modules.subsampling import Conv2dSubsampling6
+from funasr.modules.subsampling import Conv2dSubsampling8
+from funasr.modules.subsampling import TooShortUttError
+from funasr.modules.subsampling import check_short_utt
+from funasr.models.encoder.abs_encoder import AbsEncoder
+import pdb
+import math
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernerl size of conv layers; must be odd so the
+            depthwise conv preserves the time length ('SAME' padding).
+        activation (torch.nn.Module): Activation applied after batch norm.
+        bias (bool): Whether the conv layers use a bias term.
+
+    """
+
+    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = activation
+
+    def forward(self, x):
+        """Compute convolution module.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose(1, 2)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)
+
+        return x.transpose(1, 2)
+
+
+
+class MFCCAEncoder(AbsEncoder):
+ """Conformer encoder module.
+
+ Args:
+ input_size (int): Input dimension.
+ output_size (int): Dimention of attention.
+ attention_heads (int): The number of heads of multi head attention.
+ linear_units (int): The number of units of position-wise feed forward.
+ num_blocks (int): The number of decoder blocks.
+ dropout_rate (float): Dropout rate.
+ attention_dropout_rate (float): Dropout rate in attention.
+ positional_dropout_rate (float): Dropout rate after adding positional encoding.
+ input_layer (Union[str, torch.nn.Module]): Input layer type.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ If True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ If False, no additional linear will be applied. i.e. x -> x + att(x)
+ positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
+ positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
+ rel_pos_type (str): Whether to use the latest relative positional encoding or
+ the legacy one. The legacy relative positional encoding will be deprecated
+ in the future. More Details can be found in
+ https://github.com/espnet/espnet/pull/2816.
+ encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
+ encoder_attn_layer_type (str): Encoder attention layer type.
+ activation_type (str): Encoder activation function type.
+ macaron_style (bool): Whether to use macaron style for positionwise layer.
+ use_cnn_module (bool): Whether to use convolution module.
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+ cnn_module_kernel (int): Kernerl size of convolution module.
+ padding_idx (int): Padding idx for input_layer=embed.
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self._output_size = output_size
+
+ if rel_pos_type == "legacy":
+ if pos_enc_layer_type == "rel_pos":
+ pos_enc_layer_type = "legacy_rel_pos"
+ if selfattention_layer_type == "rel_selfattn":
+ selfattention_layer_type = "legacy_rel_selfattn"
+ elif rel_pos_type == "latest":
+ assert selfattention_layer_type != "legacy_rel_selfattn"
+ assert pos_enc_layer_type != "legacy_rel_pos"
+ else:
+ raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+ activation = get_activation(activation_type)
+ if pos_enc_layer_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_layer_type == "scaled_abs_pos":
+ pos_enc_class = ScaledPositionalEncoding
+ elif pos_enc_layer_type == "rel_pos":
+ assert selfattention_layer_type == "rel_selfattn"
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_layer_type == "legacy_rel_pos":
+ assert selfattention_layer_type == "legacy_rel_selfattn"
+ pos_enc_class = LegacyRelPositionalEncoding
+ logging.warning(
+ "Using legacy_rel_pos and it will be deprecated in the future."
+ )
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif isinstance(input_layer, torch.nn.Module):
+ self.embed = torch.nn.Sequential(
+ input_layer,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer is None:
+ self.embed = torch.nn.Sequential(
+ pos_enc_class(output_size, positional_dropout_rate)
+ )
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ elif selfattention_layer_type == "legacy_rel_selfattn":
+ assert pos_enc_layer_type == "legacy_rel_pos"
+ encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ logging.warning(
+ "Using legacy_rel_selfattn and it will be deprecated in the future."
+ )
+ elif selfattention_layer_type == "rel_selfattn":
+ assert pos_enc_layer_type == "rel_pos"
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ zero_triu,
+ )
+ else:
+ raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
+
+ convolution_layer = ConvolutionModule
+ convolution_layer_args = (output_size, cnn_module_kernel, activation)
+ encoder_selfattn_layer_raw = MultiHeadedAttention
+ encoder_selfattn_layer_args_raw = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EncoderLayer(
+ output_size,
+ encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+ self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+
+ self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+
+ self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+
+ self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+
+    def output_size(self) -> int:
+        """Return the encoder output feature dimension (``output_size`` arg)."""
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        channel_size: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation with multi-channel fusion.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch*channel, L, input_size);
+                the multi-channel frontend folds channels into the batch axis.
+            ilens (torch.Tensor): Input length, one entry per batch*channel row.
+            channel_size: Number of channels folded into the batch axis.
+                NOTE(review): annotated as a Tensor, but used as a plain int
+                (``reshape``/``math.ceil``); MultiChannelFrontend returns an int.
+            prev_states (torch.Tensor): Not to be used now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not to be used now.
+
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        if (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            # Conv subsampling needs a minimum frame count; fail early with
+            # a descriptive error instead of a shape mismatch deeper in.
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        # Encoder layers receive channel_size so cross-channel attention can
+        # regroup rows belonging to the same utterance.
+        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
+        if isinstance(xs_pad, tuple):
+            # (output, pos_emb) pairs are produced when rel. pos. enc. is used.
+            xs_pad = xs_pad[0]
+
+        t_leng = xs_pad.size(1)
+        d_dim = xs_pad.size(2)
+        # Regroup the channel axis the frontend folded into the batch:
+        # (batch*channel, T, D) -> (batch, channel, T, D).
+        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
+        # Tile up to exactly 8 channels (conv1 expects in_channels=8).
+        # NOTE(review): channel_size > 8 is not handled and would crash in
+        # conv1 -- confirm upstream guarantees at most 8 channels.
+        if(channel_size<8):
+            repeat_num = math.ceil(8/channel_size)
+            xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
+        # Fuse the 8 channels with the 2-D conv stack (8 -> 16 -> 32 -> 16 -> 1).
+        xs_pad = self.conv1(xs_pad)
+        xs_pad = self.conv2(xs_pad)
+        xs_pad = self.conv3(xs_pad)
+        xs_pad = self.conv4(xs_pad)
+        # NOTE(review): squeeze() removes *every* size-1 dim; if t_leng or
+        # d_dim were 1 the reshape target would be wrong -- confirm inputs
+        # cannot hit that case.
+        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
+        # Keep a single mask per utterance: take channel 0 after regrouping.
+        mask_tmp = masks.size(1)
+        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
+    def forward_hidden(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation, also capturing a mid-stack feature.
+
+        Side effect: stores the normalized output of the middle encoder layer
+        in ``self.hidden_feature`` for later retrieval.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+            ilens (torch.Tensor): Input length (#batch).
+            prev_states (torch.Tensor): Not to be used now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not to be used now.
+
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        if (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        # Run the layers one-by-one (instead of self.encoders(...) as in
+        # forward()) so the half-depth activation can be captured.
+        # NOTE(review): each layer is called here with (xs_pad, masks) only,
+        # while forward() passes an extra channel_size -- confirm the
+        # EncoderLayer supports both call signatures.
+        num_layer = len(self.encoders)
+        for idx, encoder in enumerate(self.encoders):
+            xs_pad, masks = encoder(xs_pad, masks)
+            # Capture the output of the middle layer (index num_layer//2 - 1).
+            # NOTE(review): with num_blocks == 1 this index is -1 and never
+            # matches, leaving hidden_feature unbound below -- verify callers
+            # always configure at least 2 blocks.
+            if idx == num_layer // 2 - 1:
+                hidden_feature = xs_pad
+        if isinstance(xs_pad, tuple):
+            # Strip the positional-embedding companion from both tensors.
+            xs_pad = xs_pad[0]
+            hidden_feature = hidden_feature[0]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+            self.hidden_feature = self.after_norm(hidden_feature)
+
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index fad6b70..9671fe9 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -131,3 +131,128 @@
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
+
+
+
+
+class MultiChannelFrontend(AbsFrontend):
+    """Multi-channel frontend structure for ASR.
+
+    Stft -> [optional enhancement ``Frontend``] -> Power-spec -> Mel-Fbank.
+    Multi-channel input is flattened so that each channel becomes an
+    independent row on the batch axis; the channel count is returned so the
+    encoder can regroup the rows later.
+    """
+
+    def __init__(
+        self,
+        fs: Union[int, str] = 16000,
+        n_fft: int = 512,
+        win_length: int = None,
+        hop_length: int = 128,
+        window: Optional[str] = "hann",
+        center: bool = True,
+        normalized: bool = False,
+        onesided: bool = True,
+        n_mels: int = 80,
+        fmin: int = None,
+        fmax: int = None,
+        htk: bool = False,
+        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+        apply_stft: bool = True,
+        frame_length: int = None,
+        frame_shift: int = None,
+        lfr_m: int = None,
+        lfr_n: int = None,
+    ):
+        # NOTE(review): frame_length / frame_shift / lfr_m / lfr_n are
+        # accepted but never used in this class -- kept for config
+        # compatibility with the other frontends, presumably; confirm.
+        assert check_argument_types()
+        super().__init__()
+        if isinstance(fs, str):
+            # Allow e.g. "16k" as a sampling-rate string.
+            fs = humanfriendly.parse_size(fs)
+
+        # Deepcopy (In general, dict shouldn't be used as default arg)
+        frontend_conf = copy.deepcopy(frontend_conf)
+        self.hop_length = hop_length
+
+        if apply_stft:
+            self.stft = Stft(
+                n_fft=n_fft,
+                win_length=win_length,
+                hop_length=hop_length,
+                center=center,
+                window=window,
+                normalized=normalized,
+                onesided=onesided,
+            )
+        else:
+            # Caller supplies pre-computed STFT (real/imag) input instead.
+            self.stft = None
+        self.apply_stft = apply_stft
+
+        if frontend_conf is not None:
+            # Optional speech-enhancement frontend operating on the STFT.
+            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
+        else:
+            self.frontend = None
+
+        self.logmel = LogMel(
+            fs=fs,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+            htk=htk,
+        )
+        self.n_mels = n_mels
+        self.frontend_type = "multichannelfrontend"
+
+    def output_size(self) -> int:
+        """Return the output feature dimension (number of Mel bins)."""
+        return self.n_mels
+
+    def forward(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        """Convert waveforms to log-Mel features, flattening channels.
+
+        Returns (feats, feats_lens, channel_size); note the 3-tuple -- one
+        element more than the standard frontend contract.
+        """
+        # 1. Domain-conversion: e.g. Stft: time -> time-freq
+        if self.stft is not None:
+            input_stft, feats_lens = self._compute_stft(input, input_lengths)
+        else:
+            if isinstance(input, ComplexTensor):
+                input_stft = input
+            else:
+                # Last axis holds (real, imag).
+                input_stft = ComplexTensor(input[..., 0], input[..., 1])
+            feats_lens = input_lengths
+        # 2. [Option] Speech enhancement
+        if self.frontend is not None:
+            assert isinstance(input_stft, ComplexTensor), type(input_stft)
+            # input_stft: (Batch, Length, [Channel], Freq)
+            input_stft, _, mask = self.frontend(input_stft, feats_lens)
+        # 3. STFT -> Power spectrum
+        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
+        input_power = input_stft.real ** 2 + input_stft.imag ** 2
+
+        # 4. Feature transform e.g. Stft -> Log-Mel-Fbank
+        # input_power: (Batch, [Channel,] Length, Freq)
+        # -> input_feats: (Batch, Length, Dim)
+        input_feats, _ = self.logmel(input_power, feats_lens)
+        bt = input_feats.size(0)
+        if input_feats.dim() ==4:
+            # Multi-channel: (B, T, C, D) -> (B, C, T, D) -> (B*C, T, D),
+            # folding channels into the batch axis (channel-major per batch).
+            channel_size = input_feats.size(2)
+            # NOTE(review): the literal 80 should presumably be self.n_mels;
+            # this breaks silently for n_mels != 80 -- confirm.
+            input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
+            # NOTE(review): repeat(1, C) tiles lengths as [l1..lB, l1..lB, ...]
+            # while the feature reshape above orders rows as
+            # [b1c1..b1cC, b2c1..]; the two orderings only agree for B == 1 or
+            # equal lengths within a batch -- verify intended usage.
+            feats_lens = feats_lens.repeat(1,channel_size).squeeze()
+        else:
+            channel_size = 1
+        return input_feats, feats_lens, channel_size
+
+    def _compute_stft(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[ComplexTensor, torch.Tensor]:
+        """Run the STFT and wrap the (real, imag) pair as a ComplexTensor."""
+        input_stft, feats_lens = self.stft(input, input_lengths)
+
+        assert input_stft.dim() >= 4, input_stft.shape
+        # "2" refers to the real/imag parts of Complex
+        assert input_stft.shape[-1] == 2, input_stft.shape
+
+        # Change torch.Tensor to ComplexTensor
+        # input_stft: (..., F, 2) -> (..., F)
+        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
+        return input_stft, feats_lens
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e62a748..23ac976 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -40,6 +40,7 @@
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -47,8 +48,10 @@
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
@@ -86,6 +89,7 @@
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
+ multichannelfrontend=MultiChannelFrontend,
),
type_check=AbsFrontend,
default="default",
@@ -119,6 +123,7 @@
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
+ mfcca=MFCCA,
),
type_check=AbsESPnetModel,
default="asr",
@@ -142,6 +147,7 @@
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
@@ -1106,3 +1112,135 @@
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
+
+
+
+class ASRTaskMFCCA(ASRTask):
+    """ASR task variant for the MFCCA (multi-frame cross-channel attention)
+    model: same CLI surface as ASRTask, but build_model() wires in the
+    multi-channel frontend/encoder and passes a ``rnnt_decoder`` kwarg.
+    """
+
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        """Build the full MFCCA ASR model from parsed command-line args.
+
+        Raises:
+            RuntimeError: if ``args.token_list`` is neither a path nor a list.
+        """
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            # token_list given as a file path: one token per line.
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                # wav_frontend additionally needs the CMVN statistics file.
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(stats_file=args.cmvn_file,**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 6. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder.output_size(),
+            **args.decoder_conf,
+        )
+
+        # 7. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
+        )
+
+
+        # 8. Resolve the model class; fall back to the plain "asr" model.
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+
+        # RNN-T is not supported here; kwarg kept for the model signature.
+        rnnt_decoder = None
+
+        # 9. Build model
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            rnnt_decoder=rnnt_decoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+
--
Gitblit v1.9.1