From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2
---
funasr/frontends/utils/log_mel.py | 0
funasr/frontends/wav_frontend_kaldifeat.py | 0
funasr/models/paraformer/search.py | 2
funasr/models/ct_transformer/vad_realtime_transformer.py | 2
funasr/datasets/audio_datasets/__init__.py | 0
funasr/models/neat_contextual_paraformer/model.py | 533 ++
funasr/models/transformer/model.py | 71
examples/industrial_data_pretraining/paraformer-large/finetune.sh | 0
funasr/models/uniasr/e2e_uni_asr.py | 8
funasr/bin/inference.py | 218
funasr/models/sond/attention.py | 328 +
funasr/models/transformer/encoder.py | 332 +
funasr/frontends/__init__.py | 0
funasr/models/scama/sanm_decoder.py | 552 +
funasr/models/e_branchformer/encoder.py | 15
funasr/models/bici_paraformer/cif_predictor.py | 340 +
funasr/models/branchformer/model.py | 52
funasr/tokenizer/char_tokenizer.py | 4
funasr/models/sanm/model.py | 18
funasr/models/mfcca/mfcca_encoder.py | 14
funasr/models/paraformer/cif_predictor.py | 376 -
funasr/models/conformer/model.py | 52
funasr/models/bat/model.py | 39
funasr/models/eend/encoder.py | 2
examples/industrial_data_pretraining/paraformer-large/infer.sh | 10
funasr/frontends/fused.py | 8
funasr/datasets/audio_datasets/load_audio_extract_fbank.py | 11
funasr/utils/timestamp_tools.py | 2
funasr/tokenizer/phoneme_tokenizer.py | 2
funasr/models/ct_transformer/attention.py | 1091 ++++
funasr/models/paraformer_online/sanm_decoder.py | 507 +
funasr/models/paraformer/model.py | 1322 ----
funasr/train_utils/trainer.py | 13
funasr/utils/register.py | 72
funasr/models/mfcca/e2e_asr_mfcca.py | 6
funasr/frontends/utils/stft.py | 2
funasr/datasets/audio_datasets/index_ds.py | 64
funasr/frontends/utils/__init__.py | 0
funasr/models/language_model/rnn/encoders.py | 2
funasr/models/sa_asr/e2e_sa_asr.py | 6
funasr/models/paraformer/template.yaml | 126
funasr/models/transformer/utils/nets_utils.py | 21
funasr/bin/train.py | 89
funasr/models/language_model/rnn/decoders.py | 2
funasr/models/transducer/rnn_decoder.py | 7
funasr/datasets/audio_datasets/datasets.py | 83
funasr/datasets/__init__.py | 0
funasr/models/tp_aligner/e2e_tp.py | 8
funasr/models/sond/encoder/conv_encoder.py | 2
funasr/frontends/default.py | 17
funasr/schedulers/__init__.py | 2
funasr/frontends/utils/frontend.py | 4
funasr/datasets/audio_datasets/samplers.py | 39
funasr/frontends/utils/dnn_beamformer.py | 8
funasr/models/paraformer_online/model.py | 1284 ++++
funasr/models/bat/conformer_chunk_encoder.py | 701 ++
funasr/models/language_model/transformer_encoder.py | 231
funasr/models/cnn/ResNet_aug.py | 2
funasr/models/data2vec/data2vec.py | 25
funasr/download/download_from_hub.py | 10
funasr/models/bat/cif_predictor.py | 220
funasr/models/sond/e2e_diar_sond.py | 2
funasr/models/transformer/decoder.py | 647 ++
funasr/train_utils/model_summary.py | 8
funasr/optimizers/__init__.py | 2
funasr/frontends/s3prl.py | 9
funasr/__init__.py | 23
funasr/download/runtime_sdk_download_tool.py | 73
funasr/frontends/windowing.py | 4
funasr/models/transformer/positionwise_feed_forward.py | 22
funasr/models/e_branchformer/model.py | 52
funasr/tokenizer/abs_tokenizer.py | 12
funasr/frontends/utils/mask_estimator.py | 0
funasr/metrics/compute_acc.py | 23
funasr/models/normalize/global_mvn.py | 3
funasr/models/sa_asr/transformer_decoder.py | 442 -
funasr/models/specaug/specaug.py | 5
funasr/frontends/utils/feature_transform.py | 0
funasr/models/ct_transformer/target_delay_transformer.py | 2
funasr/models/sond/encoder/self_attention_encoder.py | 16
funasr/frontends/eend_ola_feature.py | 0
funasr/models/neat_contextual_paraformer/decoder.py | 12
funasr/models/sanm/positionwise_feed_forward.py | 34
funasr/models/sond/encoder/fsmn_encoder.py | 2
funasr/models/bat/attention.py | 238
funasr/models/ct_transformer/sanm_encoder.py | 383 +
funasr/models/transformer/attention.py | 774 --
funasr/models/data2vec/data2vec_encoder.py | 3
funasr/models/fsmn_vad/model.py | 2
funasr/models/sanm/decoder.py | 474 +
funasr/models/sanm/attention.py | 641 ++
funasr/frontends/wav_frontend.py | 18
funasr/models/cnn/DTDNN.py | 2
funasr/frontends/utils/complex_utils.py | 0
funasr/models/bici_paraformer/model.py | 328 +
funasr/models/normalize/utterance_mvn.py | 3
funasr/models/eend/e2e_diar_eend_ola.py | 2
funasr/models/branchformer/encoder.py | 11
funasr/models/scama/sanm_encoder.py | 613 ++
funasr/models/transducer/model.py | 27
funasr/models/paraformer/decoder.py | 625 ++
/dev/null | 14
funasr/frontends/utils/beamformer.py | 0
funasr/models/sa_asr/attention.py | 51
funasr/bin/tokenize_text.py | 4
funasr/models/conformer/encoder.py | 613 ++
funasr/frontends/utils/dnn_wpe.py | 2
funasr/models/cnn/ResNet.py | 2
funasr/models/xvector/e2e_sv.py | 6
funasr/models/sanm/encoder.py | 454 +
 110 files changed, 11688 insertions(+), 3952 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/finetune.sh
similarity index 100%
rename from examples/industrial_data_pretraining/paraformer-large/run.sh
rename to examples/industrial_data_pretraining/paraformer-large/finetune.sh
diff --git a/examples/industrial_data_pretraining/paraformer-large/infer.sh b/examples/industrial_data_pretraining/paraformer-large/infer.sh
index b7fbe75..48ad3bf 100644
--- a/examples/industrial_data_pretraining/paraformer-large/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/infer.sh
@@ -2,6 +2,12 @@
cmd="funasr/bin/inference.py"
python $cmd \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
++device="cpu" \
+
+python $cmd \
+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+input="/Users/zhifu/Downloads/asr_example.wav" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
@@ -12,4 +18,6 @@
#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
#+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
-#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
\ No newline at end of file
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+#+"hotword='杈鹃瓟闄� 榄旀惌'"
\ No newline at end of file
diff --git a/funasr/__init__.py b/funasr/__init__.py
index ac1591d..c7cc3b6 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,10 +1,31 @@
"""Initialize funasr package."""
import os
+import pkgutil
+import importlib
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
-from funasr.bin.inference import infer
\ No newline at end of file
+
+
+def import_submodules(package, recursive=True):
+ if isinstance(package, str):
+ package = importlib.import_module(package)
+ results = {}
+ for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'):
+ try:
+ results[name] = importlib.import_module(name)
+ except Exception as e:
+            # To see the details of a failed import, uncomment the line below:
+ # print(f"Failed to import {name}: {e}")
+ pass
+ if recursive and is_pkg:
+ results.update(import_submodules(name))
+ return results
+
+import_submodules(__name__)
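Note on the hunk above: the eager `import_submodules(__name__)` call makes every `@register_class(...)` decorator in the package run at import time. `funasr/utils/register.py` is listed in the diffstat but its body is not part of this excerpt, so the following is only a hypothetical sketch of such a registry, inferred from the call sites used later in this patch (`register_class`, `registry_tables.model_classes.get(...)`, `print_register_tables`).

```python
# Hypothetical sketch of funasr/utils/register.py (not shown in this patch excerpt);
# names mirror the call sites that appear elsewhere in the diff.
class RegistryTables:
    def __init__(self):
        # one dict per registry table referenced in this patch
        self.model_classes = {}
        self.dataset_classes = {}
        self.index_ds_classes = {}
        self.batch_sampler_classes = {}
        self.tokenizer_classes = {}
        self.frontend_classes = {}

    def print_register_tables(self):
        for table_name, table in vars(self).items():
            print(table_name, sorted(table.keys()))


registry_tables = RegistryTables()


def register_class(table_name: str, register_name: str):
    """Store the decorated class under register_name.lower() in the given table."""
    def wrapper(cls):
        getattr(registry_tables, table_name)[register_name.lower()] = cls
        return cls
    return wrapper
```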
diff --git a/funasr/bin/aggregate_stats_dirs.py b/funasr/bin/aggregate_stats_dirs.py
deleted file mode 100755
index 94cbdf8..0000000
--- a/funasr/bin/aggregate_stats_dirs.py
+++ /dev/null
@@ -1,108 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Iterable
-from typing import Union
-
-import numpy as np
-
-from funasr.utils.cli_utils import get_commandline_args
-
-
-def aggregate_stats_dirs(
- input_dir: Iterable[Union[str, Path]],
- output_dir: Union[str, Path],
- log_level: str,
- skip_sum_stats: bool,
-):
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s",
- )
-
- input_dirs = [Path(p) for p in input_dir]
- output_dir = Path(output_dir)
-
- for mode in ["train", "valid"]:
- with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f:
- batch_keys = [line.strip() for line in f if line.strip() != ""]
- with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f:
- stats_keys = [line.strip() for line in f if line.strip() != ""]
- (output_dir / mode).mkdir(parents=True, exist_ok=True)
-
- for key in batch_keys:
- with (output_dir / mode / f"{key}_shape").open(
- "w", encoding="utf-8"
- ) as fout:
- for idir in input_dirs:
- with (idir / mode / f"{key}_shape").open(
- "r", encoding="utf-8"
- ) as fin:
- # Read to the last in order to sort keys
- # because the order can be changed if num_workers>=1
- lines = fin.readlines()
- lines = sorted(lines, key=lambda x: x.split()[0])
- for line in lines:
- fout.write(line)
-
- for key in stats_keys:
- if not skip_sum_stats:
- sum_stats = None
- for idir in input_dirs:
- stats = np.load(idir / mode / f"{key}_stats.npz")
- if sum_stats is None:
- sum_stats = dict(**stats)
- else:
- for k in stats:
- sum_stats[k] += stats[k]
-
- np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)
-
- # if --write_collected_feats=true
- p = Path(mode) / "collect_feats" / f"{key}.scp"
- scp = input_dirs[0] / p
- if scp.exists():
- (output_dir / p).parent.mkdir(parents=True, exist_ok=True)
- with (output_dir / p).open("w", encoding="utf-8") as fout:
- for idir in input_dirs:
- with (idir / p).open("r", encoding="utf-8") as fin:
- for line in fin:
- fout.write(line)
-
-
-def get_parser() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser(
- description="Aggregate statistics directories into one directory",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
- parser.add_argument(
- "--skip_sum_stats",
- default=False,
- action="store_true",
- help="Skip computing the sum of statistics.",
- )
-
- parser.add_argument("--input_dir", action="append", help="Input directories")
- parser.add_argument("--output_dir", required=True, help="Output directory")
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- aggregate_stats_dirs(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/export_model.py b/funasr/bin/export_model.py
deleted file mode 100644
index 6ab9408..0000000
--- a/funasr/bin/export_model.py
+++ /dev/null
@@ -1,296 +0,0 @@
-import os
-import torch
-import random
-import logging
-import numpy as np
-from pathlib import Path
-from typing import Union, Dict, List
-from funasr.export.models import get_model
-from funasr.utils.types import str2bool, str2triple_str
-# torch_version = float(".".join(torch.__version__.split(".")[:2]))
-# assert torch_version > 1.9
-
-class ModelExport:
- def __init__(
- self,
- cache_dir: Union[Path, str] = None,
- onnx: bool = True,
- device: str = "cpu",
- quant: bool = True,
- fallback_num: int = 0,
- audio_in: str = None,
- calib_num: int = 200,
- model_revision: str = None,
- ):
- self.set_all_random_seed(0)
-
- self.cache_dir = cache_dir
- self.export_config = dict(
- feats_dim=560,
- onnx=False,
- )
-
- self.onnx = onnx
- self.device = device
- self.quant = quant
- self.fallback_num = fallback_num
- self.frontend = None
- self.audio_in = audio_in
- self.calib_num = calib_num
- self.model_revision = model_revision
-
-
- def _export(
- self,
- model,
- tag_name: str = None,
- verbose: bool = False,
- ):
-
- export_dir = self.cache_dir
- os.makedirs(export_dir, exist_ok=True)
-
- # export encoder1
- self.export_config["model_name"] = "model"
- model = get_model(
- model,
- self.export_config,
- )
- if isinstance(model, List):
- for m in model:
- m.eval()
- if self.onnx:
- self._export_onnx(m, verbose, export_dir)
- else:
- self._export_torchscripts(m, verbose, export_dir)
- print("output dir: {}".format(export_dir))
- else:
- model.eval()
- # self._export_onnx(model, verbose, export_dir)
- if self.onnx:
- self._export_onnx(model, verbose, export_dir)
- else:
- self._export_torchscripts(model, verbose, export_dir)
- print("output dir: {}".format(export_dir))
-
-
- def _torch_quantize(self, model):
- def _run_calibration_data(m):
- # using dummy inputs for a example
- if self.audio_in is not None:
- feats, feats_len = self.load_feats(self.audio_in)
- for i, (feat, len) in enumerate(zip(feats, feats_len)):
- with torch.no_grad():
- m(feat, len)
- else:
- dummy_input = model.get_dummy_inputs()
- m(*dummy_input)
-
-
- from torch_quant.module import ModuleFilter
- from torch_quant.quantizer import Backend, Quantizer
- from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
- from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
- module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
- module_filter.exclude_op_types = [torch.nn.Conv1d]
- quantizer = Quantizer(
- module_filter=module_filter,
- backend=Backend.FBGEMM,
- )
- model.eval()
- calib_model = quantizer.calib(model)
- _run_calibration_data(calib_model)
- if self.fallback_num > 0:
- # perform automatic mixed precision quantization
- amp_model = quantizer.amp(model)
- _run_calibration_data(amp_model)
- quantizer.fallback(amp_model, num=self.fallback_num)
- print('Fallback layers:')
- print('\n'.join(quantizer.module_filter.exclude_names))
- quant_model = quantizer.quantize(model)
- return quant_model
-
-
- def _export_torchscripts(self, model, verbose, path, enc_size=None):
- if enc_size:
- dummy_input = model.get_dummy_inputs(enc_size)
- else:
- dummy_input = model.get_dummy_inputs()
-
- if self.device == 'cuda':
- model = model.cuda()
- dummy_input = tuple([i.cuda() for i in dummy_input])
-
- # model_script = torch.jit.script(model)
- model_script = torch.jit.trace(model, dummy_input)
- model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
-
- if self.quant:
- quant_model = self._torch_quantize(model)
- model_script = torch.jit.trace(quant_model, dummy_input)
- model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))
-
-
- def set_all_random_seed(self, seed: int):
- random.seed(seed)
- np.random.seed(seed)
- torch.random.manual_seed(seed)
-
- def parse_audio_in(self, audio_in):
-
- wav_list, name_list = [], []
- if audio_in.endswith(".scp"):
- f = open(audio_in, 'r')
- lines = f.readlines()[:self.calib_num]
- for line in lines:
- name, path = line.strip().split()
- name_list.append(name)
- wav_list.append(path)
- else:
- wav_list = [audio_in,]
- name_list = ["test",]
- return wav_list, name_list
-
- def load_feats(self, audio_in: str = None):
- import torchaudio
-
- wav_list, name_list = self.parse_audio_in(audio_in)
- feats = []
- feats_len = []
- for line in wav_list:
- path = line.strip()
- waveform, sampling_rate = torchaudio.load(path)
- if sampling_rate != self.frontend.fs:
- waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
- new_freq=self.frontend.fs)(waveform)
- fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
- feats.append(fbank)
- feats_len.append(fbank_len)
- return feats, feats_len
-
- def export(self,
- tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- mode: str = None,
- ):
-
- model_dir = tag_name
- if model_dir.startswith('damo'):
- from modelscope.hub.snapshot_download import snapshot_download
- model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
- self.cache_dir = model_dir
-
- if mode is None:
- import json
- json_file = os.path.join(model_dir, 'configuration.json')
- with open(json_file, 'r') as f:
- config_data = json.load(f)
- if config_data['task'] == "punctuation":
- mode = config_data['model']['punc_model_config']['mode']
- else:
- mode = config_data['model']['model_config']['mode']
- if mode.startswith('paraformer'):
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- config = os.path.join(model_dir, 'config.yaml')
- model_file = os.path.join(model_dir, 'model.pb')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
- model, asr_train_args = ASRTask.build_model_from_file(
- config, model_file, cmvn_file, 'cpu'
- )
- self.frontend = model.frontend
- self.export_config["feats_dim"] = 560
- elif mode.startswith('offline'):
- from funasr.tasks.vad import VADTask
- config = os.path.join(model_dir, 'vad.yaml')
- model_file = os.path.join(model_dir, 'vad.pb')
- cmvn_file = os.path.join(model_dir, 'vad.mvn')
-
- model, vad_infer_args = VADTask.build_model_from_file(
- config, model_file, cmvn_file=cmvn_file, device='cpu'
- )
- self.export_config["feats_dim"] = 400
- self.frontend = model.frontend
- elif mode.startswith('punc'):
- from funasr.tasks.punctuation import PunctuationTask as PUNCTask
- punc_train_config = os.path.join(model_dir, 'config.yaml')
- punc_model_file = os.path.join(model_dir, 'punc.pb')
- model, punc_train_args = PUNCTask.build_model_from_file(
- punc_train_config, punc_model_file, 'cpu'
- )
- elif mode.startswith('punc_VadRealtime'):
- from funasr.tasks.punctuation import PunctuationTask as PUNCTask
- punc_train_config = os.path.join(model_dir, 'config.yaml')
- punc_model_file = os.path.join(model_dir, 'punc.pb')
- model, punc_train_args = PUNCTask.build_model_from_file(
- punc_train_config, punc_model_file, 'cpu'
- )
- self._export(model, tag_name)
-
-
- def _export_onnx(self, model, verbose, path, enc_size=None):
- if enc_size:
- dummy_input = model.get_dummy_inputs(enc_size)
- else:
- dummy_input = model.get_dummy_inputs()
-
- # model_script = torch.jit.script(model)
- model_script = model #torch.jit.trace(model)
- model_path = os.path.join(path, f'{model.model_name}.onnx')
- # if not os.path.exists(model_path):
- torch.onnx.export(
- model_script,
- dummy_input,
- model_path,
- verbose=verbose,
- opset_version=14,
- input_names=model.get_input_names(),
- output_names=model.get_output_names(),
- dynamic_axes=model.get_dynamic_axes()
- )
-
- if self.quant:
- from onnxruntime.quantization import QuantType, quantize_dynamic
- import onnx
- quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
- if not os.path.exists(quant_model_path):
- onnx_model = onnx.load(model_path)
- nodes = [n.name for n in onnx_model.graph.node]
- nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
- quantize_dynamic(
- model_input=model_path,
- model_output=quant_model_path,
- op_types_to_quantize=['MatMul'],
- per_channel=True,
- reduce_range=False,
- weight_type=QuantType.QUInt8,
- nodes_to_exclude=nodes_to_exclude,
- )
-
-
-if __name__ == '__main__':
- import argparse
- parser = argparse.ArgumentParser()
- # parser.add_argument('--model-name', type=str, required=True)
- parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
- parser.add_argument('--export-dir', type=str, required=True)
- parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
- parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
- parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
- parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
- parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
- parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
- parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
- args = parser.parse_args()
-
- export_model = ModelExport(
- cache_dir=args.export_dir,
- onnx=args.type == 'onnx',
- device=args.device,
- quant=args.quantize,
- fallback_num=args.fallback_num,
- audio_in=args.audio_in,
- calib_num=args.calib_num,
- model_revision=args.model_revision,
- )
- for model_name in args.model_name:
- print("export model: {}".format(model_name))
- export_model.export(model_name)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index d63ebc9..09e28f3 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -5,123 +5,25 @@
import hydra
import json
from omegaconf import DictConfig, OmegaConf
-from funasr.utils.dynamic_import import dynamic_import
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
import time
import random
import string
+from funasr.utils.register import registry_tables
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
- assert "model" in kwargs
- pipeline = infer(**kwargs)
- res = pipeline(input=kwargs["input"])
- print(res)
-
-def infer(**kwargs):
-
- if ":" not in kwargs["model"]:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
- kwargs = download_model(**kwargs)
-
- set_all_random_seed(kwargs.get("seed", 0))
-
-
- device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1):
- device = "cpu"
- batch_size = 1
- kwargs["device"] = device
-
- # build_tokenizer
- tokenizer = build_tokenizer(
- token_type=kwargs.get("token_type", "char"),
- bpemodel=kwargs.get("bpemodel", None),
- delimiter=kwargs.get("delimiter", None),
- space_symbol=kwargs.get("space_symbol", "<space>"),
- non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
- g2p_type=kwargs.get("g2p_type", None),
- token_list=kwargs.get("token_list", None),
- unk_symbol=kwargs.get("unk_symbol", "<unk>"),
- )
-
- import pdb;
- pdb.set_trace()
- # build model
- model_class = dynamic_import(kwargs.get("model"))
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- model.eval()
- model.to(device)
- frontend = model.frontend
- kwargs["token_list"] = tokenizer.token_list
-
-
- # init_param
- init_param = kwargs.get("init_param", None)
- if init_param is not None:
- logging.info(f"Loading pretrained params from {init_param}")
- load_pretrained_model(
- model=model,
- init_param=init_param,
- ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
- oss_bucket=kwargs.get("oss_bucket", None),
- )
-
- def _forward(input, input_len=None, **cfg):
- cfg = OmegaConf.merge(kwargs, cfg)
- date_type = cfg.get("date_type", "sound")
-
- key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
-
- speed_stats = {}
- asr_result_list = []
- num_samples = len(data_list)
- pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
- for beg_idx in range(0, num_samples, batch_size):
-
- end_idx = min(num_samples, beg_idx + batch_size)
- data_batch = data_list[beg_idx:end_idx]
- key_batch = key_list[beg_idx:end_idx]
- batch = {"data_in": data_batch, "key": key_batch}
-
- time1 = time.perf_counter()
- results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
- time2 = time.perf_counter()
-
- asr_result_list.append(results)
- pbar.update(1)
-
- # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
- batch_data_time = meta_data.get("batch_data_time", -1)
- speed_stats["load_data"] = meta_data["load_data"]
- speed_stats["extract_feat"] = meta_data["extract_feat"]
- speed_stats["forward"] = f"{time2 - time1:0.3f}"
- speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
- description = (
- f"{speed_stats}, "
- )
- pbar.set_description(description)
-
- torch.cuda.empty_cache()
- return asr_result_list
-
- return _forward
-
-
-def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
+def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
"""
:param input:
:param input_len:
- :param date_type:
+ :param data_type:
:param frontend:
:return:
"""
@@ -131,7 +33,7 @@
chars = string.ascii_letters + string.digits
- if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
+ if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
@@ -153,10 +55,10 @@
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
- elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
+ elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
- else: # raw text; audio sample point, fbank
+ else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
@@ -165,6 +67,112 @@
return key_list, data_list
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+ logging.basicConfig(level=log_level)
+
+    # import pdb;
+    # pdb.set_trace()
+ model = AutoModel(**kwargs)
+ res = model.generate(input=kwargs["input"])
+ print(res)
+
+class AutoModel:
+ def __init__(self, **kwargs):
+ registry_tables.print_register_tables()
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ if frontend is not None:
+ frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+ frontend = frontend_class(**kwargs["frontend_conf"])
+ kwargs["frontend"] = frontend
+
+ # build model
+ model_class = registry_tables.model_classes.get(kwargs["model"].lower())
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ model.eval()
+ model.to(device)
+
+ kwargs["token_list"] = tokenizer.token_list
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ init_param=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+ self.kwargs = kwargs
+ self.model = model
+ self.tokenizer = tokenizer
+
+ def generate(self, input, input_len=None, **cfg):
+ self.kwargs.update(cfg)
+ data_type = self.kwargs.get("data_type", "sound")
+ batch_size = self.kwargs.get("batch_size", 1)
+ if self.kwargs.get("device", "cpu") == "cpu":
+ batch_size = 1
+
+ key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+ if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
+ batch["data_batch"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ results, meta_data = self.model.generate(**batch, **self.kwargs)
+ time2 = time.perf_counter()
+
+ asr_result_list.append(results)
+ pbar.update(1)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time2 - time1:0.3f}"
+ speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
+ description = (
+ f"{speed_stats}, "
+ )
+ pbar.set_description(description)
+
+ torch.cuda.empty_cache()
+ return asr_result_list
if __name__ == '__main__':
main_hydra()
\ No newline at end of file
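A minimal usage sketch of the `AutoModel` entry point added in the hunk above, doing from Python what `main_hydra` does via hydra overrides; the model id and wav path are placeholders taken from infer.sh earlier in this patch, not values the API requires.

```python
# Illustrative only: mirrors what main_hydra() does, using the AutoModel class added above.
from funasr.bin.inference import AutoModel

model = AutoModel(
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",  # placeholder id
    device="cpu",
)
results = model.generate(input="asr_example.wav")  # placeholder wav path
print(results)
```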
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
index 674c1b9..0ecf2f6 100755
--- a/funasr/bin/tokenize_text.py
+++ b/funasr/bin/tokenize_text.py
@@ -11,7 +11,7 @@
from funasr.utils.cli_utils import get_commandline_args
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_classes
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
@@ -239,7 +239,7 @@
parser.add_argument(
"--g2p",
type=str_or_none,
- choices=g2p_choices,
+ choices=g2p_classes,
default=None,
help="Specify g2p method if --token_type=phn",
)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 6a88233..72fa9fa 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -8,34 +8,30 @@
import hydra
from omegaconf import DictConfig, OmegaConf
from funasr.train_utils.set_all_random_seed import set_all_random_seed
-# from funasr.model_class_factory1 import model_choices
from funasr.models.lora.utils import mark_only_lora_as_trainable
-from funasr.optimizers import optim_choices
-from funasr.schedulers import scheduler_choices
+from funasr.optimizers import optim_classes
+from funasr.schedulers import scheduler_classes
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.train_utils.initialize import initialize
-from funasr.datasets.fun_datasets.data_sampler import BatchSampler
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.dataset_jsonl import AudioDataset
+# from funasr.tokenizer.funtoken import build_tokenizer
from funasr.train_utils.trainer import Trainer
-# from funasr.utils.load_fr_py import load_class_from_path
-from funasr.utils.dynamic_import import dynamic_import
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.download.download_from_hub import download_model
+from funasr.utils.register import registry_tables
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
import pdb; pdb.set_trace()
- if ":" in kwargs["model"]:
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
- import pdb;
- pdb.set_trace()
+
main(**kwargs)
@@ -43,6 +39,7 @@
# preprocess_config(kwargs)
# import pdb; pdb.set_trace()
# set random seed
+ registry_tables.print_register_tables()
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -56,31 +53,38 @@
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
torch.cuda.set_device(local_rank)
-
- # build_tokenizer
- tokenizer = build_tokenizer(
- token_type=kwargs.get("token_type", "char"),
- bpemodel=kwargs.get("bpemodel", None),
- delimiter=kwargs.get("delimiter", None),
- space_symbol=kwargs.get("space_symbol", "<space>"),
- non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
- g2p_type=kwargs.get("g2p_type", None),
- token_list=kwargs.get("token_list", None),
- unk_symbol=kwargs.get("unk_symbol", "<unk>"),
- )
+ # save config.yaml
+ if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ logging.info("config.yaml is saved to: %s", yaml_file)
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+ kwargs["tokenizer"] = tokenizer
+
+    # build frontend if frontend is not None
+ frontend = kwargs.get("frontend", None)
+ if frontend is not None:
+ frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+ frontend = frontend_class(**kwargs["frontend_conf"])
+ kwargs["frontend"] = frontend
+
# import pdb;
# pdb.set_trace()
# build model
- # model_class = model_choices.get_class(kwargs.get("model", "asr"))
- # model_class = load_class_from_path(kwargs.get("model").split(":"))
- model_class = dynamic_import(kwargs.get("model"))
+ model_class = registry_tables.model_classes.get(kwargs["model"].lower())
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
- frontend = model.frontend
+
+
+
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
- if not isinstance(init_param, Sequence):
+ if not isinstance(init_param, (list, tuple)):
init_param = (init_param,)
logging.info("init_param is not None: %s", init_param)
for p in init_param:
@@ -93,9 +97,8 @@
)
else:
initialize(model, kwargs.get("init", "kaiming_normal"))
-
- # import pdb;
- # pdb.set_trace()
+
+
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
if freeze_param is not None:
@@ -122,33 +125,33 @@
# optim
optim = kwargs.get("optim", "adam")
- assert optim in optim_choices
- optim_class = optim_choices.get(optim)
+ assert optim in optim_classes
+ optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
scheduler = kwargs.get("scheduler", "warmuplr")
- assert scheduler in scheduler_choices
- scheduler_class = scheduler_choices.get(scheduler)
+ assert scheduler in scheduler_classes
+ scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-
+ # import pdb;
+ # pdb.set_trace()
# dataset
- dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+ dataset_class = registry_tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+ dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
# dataloader
- batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
+ batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+ batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
pin_memory=True)
- if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
- os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
- yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
- OmegaConf.save(config=kwargs, f=yaml_file)
- logging.info("config.yaml is saved to: %s", yaml_file)
+
trainer = Trainer(
model=model,
diff --git a/funasr/datasets/fun_datasets/__init__.py b/funasr/datasets/__init__.py
similarity index 100%
rename from funasr/datasets/fun_datasets/__init__.py
rename to funasr/datasets/__init__.py
diff --git a/funasr/datasets/fun_datasets/__init__.py b/funasr/datasets/audio_datasets/__init__.py
similarity index 100%
copy from funasr/datasets/fun_datasets/__init__.py
copy to funasr/datasets/audio_datasets/__init__.py
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
new file mode 100644
index 0000000..353a3a0
--- /dev/null
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -0,0 +1,83 @@
+import torch
+import json
+import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
+import torchaudio
+import time
+import logging
+
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("dataset_classes", "AudioDataset")
+class AudioDataset(torch.utils.data.Dataset):
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs):
+ super().__init__()
+ index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
+ self.index_ds = index_ds_class(path)
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
+
+ self.int_pad_value = int_pad_value
+ self.float_pad_value = float_pad_value
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ # import pdb;
+ # pdb.set_trace()
+ source = item["source"]
+ data_src = load_audio(source, fs=self.fs)
+ speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+ target = item["target"]
+ ids = self.tokenizer.encode(target)
+ ids_lengths = len(ids)
+ text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+
+ return {"speech": speech[0, :, :],
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ }
+
+
+ def collator(self, samples: list=None):
+
+
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if data_list[0].dtype == torch.int64:
+
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+ return outputs
+
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
new file mode 100644
index 0000000..33b309a
--- /dev/null
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -0,0 +1,64 @@
+import torch
+import json
+import torch.distributed as dist
+import time
+import logging
+
+from funasr.utils.register import register_class
+
+@register_class("index_ds_classes", "IndexDSJsonl")
+class IndexDSJsonl(torch.utils.data.Dataset):
+
+ def __init__(self, path):
+ super().__init__()
+
+ contents = []
+ with open(path, encoding='utf-8') as fin:
+ for line in fin:
+ data = json.loads(line.strip())
+ if "text" in data: # for sft
+                contents.append(data['text'])
+ if "source" in data: # for speech lab pretrain
+ prompt = data["prompt"]
+ source = data["source"]
+ target = data["target"]
+ source_len = data["source_len"]
+ target_len = data["target_len"]
+
+ contents.append({"source": source,
+ "prompt": prompt,
+ "target": target,
+ "source_len": source_len,
+ "target_len": target_len,
+ }
+ )
+
+ self.contents = []
+ total_num = len(contents)
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+ logging.warning("distributed is not initialized, only single shard")
+ num_per_rank = total_num // world_size
+
+ # rank = 0
+ # import ipdb; ipdb.set_trace()
+ self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+
+ logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
+
+ def __len__(self):
+ return len(self.contents)
+
+ def __getitem__(self, index):
+ return self.contents[index]
+
+ def get_source_len(self, data_dict):
+ return data_dict["source_len"]
+
+ def get_target_len(self, data_dict):
+
+ return data_dict["target_len"] if "target_len" in data_dict else 0
diff --git a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py b/funasr/datasets/audio_datasets/load_audio_extract_fbank.py
similarity index 93%
rename from funasr/datasets/fun_datasets/load_audio_extract_fbank.py
rename to funasr/datasets/audio_datasets/load_audio_extract_fbank.py
index c76f346..c8883ee 100644
--- a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
+++ b/funasr/datasets/audio_datasets/load_audio_extract_fbank.py
@@ -46,15 +46,16 @@
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
-def extract_fbank(data, data_len = None, date_type: str="sound", frontend=None):
-
+def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
+ # import pdb;
+ # pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
- if len(data) < 2:
+ if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
- if len(data) < 2:
+ if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
@@ -67,7 +68,7 @@
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
# import pdb;
# pdb.set_trace()
- if date_type == "sound":
+ if data_type == "sound":
data, data_len = frontend(data, data_len)
if isinstance(data_len, (list, tuple)):
diff --git a/funasr/datasets/fun_datasets/data_sampler.py b/funasr/datasets/audio_datasets/samplers.py
similarity index 61%
rename from funasr/datasets/fun_datasets/data_sampler.py
rename to funasr/datasets/audio_datasets/samplers.py
index 3a19a17..7d3a941 100644
--- a/funasr/datasets/fun_datasets/data_sampler.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -2,31 +2,38 @@
import numpy as np
+from funasr.utils.register import register_class
+
+@register_class("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
- def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+ def __init__(self, dataset,
+ batch_type: str="example",
+ batch_size: int=100,
+ buffer_size: int=30,
+ drop_last: bool=False,
+ shuffle: bool=True,
+ **kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
- # self.batch_type = args.batch_type
- # self.batch_size = args.batch_size
- # self.sort_size = args.sort_size
- # self.max_length_token = args.max_length_token
self.batch_type = batch_type
self.batch_size = batch_size
- self.sort_size = sort_size
- self.max_length_token = kwargs.get("max_length_token", 5000)
+ self.buffer_size = buffer_size
+ self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle
def __len__(self):
return self.total_samples
-
+
+ def set_epoch(self, epoch):
+ np.random.seed(epoch)
+
def __iter__(self):
- # print("in sampler")
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
@@ -35,31 +42,31 @@
max_token = 0
num_sample = 0
- iter_num = (self.total_samples-1) // self.sort_size + 1
+ iter_num = (self.total_samples-1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
- for i in range(self.sort_size):
- idx = iter * self.sort_size + i
+ for i in range(self.buffer_size):
+ idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
- self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
+ sample_len_cur = self.dataset.get_source_len(idx_map) + \
+ self.dataset.get_target_len(idx_map)
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
idx, sample_len_cur_raw = item
- if sample_len_cur_raw > self.max_length_token:
+ if sample_len_cur_raw > self.max_token_length:
continue
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
- if self.batch_type == 'token':
+ if self.batch_type == 'length':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
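A small worked illustration of the batching rule visible in `BatchSampler.__iter__` above: with `batch_type='length'`, a candidate joins the current batch only while `(num_sample + 1) * max_token_cur` stays within `batch_size`, which then acts as a length budget rather than a sample count. The per-sample bookkeeping after the admission check is not shown in this hunk, so the update steps below are assumptions.

```python
# Assumed bookkeeping; only the admission check itself appears in the hunk above.
batch_size = 6000                                        # budget when batch_type == "length"
samples = [(3, 900), (7, 1100), (1, 1500), (5, 2600)]    # (index, length) pairs, made up

batch, num_sample, max_token = [], 0, 0
for idx, sample_len in samples:
    max_token_cur = max(max_token, sample_len)
    max_token_padding = (1 + num_sample) * max_token_cur
    if max_token_padding <= batch_size:
        batch.append(idx)
        num_sample += 1
        max_token = max_token_cur
    else:
        print("emit batch:", batch)        # the real sampler would yield here and start a new batch
        batch, num_sample, max_token = [idx], 1, sample_len

print("final batch:", batch)               # -> emit batch: [3, 7, 1]; final batch: [5]
```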
diff --git a/funasr/datasets/fun_datasets/dataloader_fn.py b/funasr/datasets/fun_datasets/dataloader_fn.py
deleted file mode 100644
index 601cbeb..0000000
--- a/funasr/datasets/fun_datasets/dataloader_fn.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import time
-import torch
-from funasr.datasets.fun_datasets.dataset_jsonl import AudioDataset
-from funasr.datasets.fun_datasets.data_sampler import BatchSampler
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tokenizer.build_tokenizer import build_tokenizer
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-collate_fn = None
-# collate_fn = collate_fn,
-
-jsonl = "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl"
-
-frontend = WavFrontend()
-token_type = 'char'
-bpemodel = None
-delimiter = None
-space_symbol = "<space>"
-non_linguistic_symbols = None
-g2p_type = None
-
-tokenizer = build_tokenizer(
- token_type=token_type,
- bpemodel=bpemodel,
- delimiter=delimiter,
- space_symbol=space_symbol,
- non_linguistic_symbols=non_linguistic_symbols,
- g2p_type=g2p_type,
-)
-token_list = "/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt"
-unk_symbol = "<unk>"
-
-token_id_converter = TokenIDConverter(
- token_list=token_list,
- unk_symbol=unk_symbol,
-)
-
-dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_converter=token_id_converter)
-batch_sampler = BatchSampler(dataset)
-
-
-if __name__ == "__main__":
-
- dataloader_tr = torch.utils.data.DataLoader(dataset,
- collate_fn=dataset.collator,
- batch_sampler=batch_sampler,
- shuffle=False,
- num_workers=0,
- pin_memory=True)
-
- print(len(dataset))
- for i in range(3):
- print(i)
- beg = time.time()
- for j, data in enumerate(dataloader_tr):
- end = time.time()
- time_cost = end - beg
- beg = end
- print(j, time_cost)
- # data_iter = iter(dataloader_tr)
- # data = next(data_iter)
- pass
-
-
diff --git a/funasr/datasets/fun_datasets/dataset_jsonl.py b/funasr/datasets/fun_datasets/dataset_jsonl.py
deleted file mode 100644
index 21df89e..0000000
--- a/funasr/datasets/fun_datasets/dataset_jsonl.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import torch
-import json
-import torch.distributed as dist
-import numpy as np
-import kaldiio
-import librosa
-import torchaudio
-import time
-import logging
-
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-
-
-
-class IndexedDatasetJsonl(torch.utils.data.Dataset):
-
- def __init__(self, path):
- super().__init__()
-
- contents = []
- with open(path, encoding='utf-8') as fin:
- for line in fin:
- data = json.loads(line.strip())
- if "text" in data: # for sft
- self.contents.append(data['text'])
- if "source" in data: # for speech lab pretrain
- prompt = data["prompt"]
- source = data["source"]
- target = data["target"]
- source_len = data["source_len"]
- target_len = data["target_len"]
-
- contents.append({"source": source,
- "prompt": prompt,
- "target": target,
- "source_len": source_len,
- "target_len": target_len,
- }
- )
-
- self.contents = []
- total_num = len(contents)
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- logging.warning("distributed is not initialized, only single shard")
- num_per_rank = total_num // world_size
-
- # rank = 0
- # import ipdb; ipdb.set_trace()
- self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-
- logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
-
- def __len__(self):
- return len(self.contents)
-
- def __getitem__(self, index):
- return self.contents[index]
-
- def get_source_len(self, data_dict):
- return data_dict["source_len"]
-
- def get_target_len(self, data_dict):
-
- return data_dict["target_len"] if "target_len" in data_dict else 0
-
-
-class AudioDataset(torch.utils.data.Dataset):
- def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
- super().__init__()
- self.indexed_dataset = IndexedDatasetJsonl(path)
- self.frontend = frontend.forward
- self.fs = 16000 if frontend is None else frontend.fs
- self.data_type = "sound"
- self.tokenizer = tokenizer
-
- self.int_pad_value = int_pad_value
- self.float_pad_value = float_pad_value
-
-
-
-
- def __len__(self):
- return len(self.indexed_dataset)
-
- def __getitem__(self, index):
- item = self.indexed_dataset[index]
-
- source = item["source"]
- data_src = load_audio(source, fs=self.fs)
- speech, speech_lengths = extract_fbank(data_src, self.data_type, self.frontend) # speech: [b, T, d]
- target = item["target"]
- ids = self.tokenizer.encode(target)
- ids_lengths = len(ids)
- text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
-
- return {"speech": speech[0, :, :],
- "speech_lengths": speech_lengths,
- "text": text,
- "text_lengths": text_lengths,
- }
-
-
- def collator(self, samples: list=None):
-
- # return samples
-
- outputs = {}
- for sample in samples:
- for key in sample.keys():
- if key not in outputs:
- outputs[key] = []
- outputs[key].append(sample[key])
-
- for key, data_list in outputs.items():
- if data_list[0].dtype == torch.int64:
-
- pad_value = self.int_pad_value
- else:
- pad_value = self.float_pad_value
- outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
- return outputs
-
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 24ab797..eeb5d0c 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -22,11 +22,17 @@
kwargs = OmegaConf.merge(cfg, kwargs)
init_param = os.path.join(model_or_path, "model.pb")
kwargs["init_param"] = init_param
- kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+    if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+
kwargs["model"] = cfg["model"]
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
- return kwargs
+ return OmegaConf.to_container(kwargs, resolve=True)
def get_or_download_model_dir(
model,
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 91c5844..92416f4 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -3,38 +3,43 @@
import argparse
from funasr.utils.types import str2bool
-parser = argparse.ArgumentParser()
-parser.add_argument('--model-name', type=str, required=True)
-parser.add_argument('--export-dir', type=str, required=True)
-parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
-parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
-parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
-parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
-parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
-parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
-parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
-parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
-args = parser.parse_args()
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model-name', type=str, required=True)
+ parser.add_argument('--export-dir', type=str, required=True)
+ parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
+ parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+ parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+ parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+ parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+ parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+ parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+ parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+ args = parser.parse_args()
+
+ model_dir = args.model_name
+ if not Path(args.model_name).exists():
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
+        except Exception:
+            raise ValueError("model_dir must be a model name on modelscope or a local path "
+                             "downloaded from modelscope, but got: {}".format(model_dir))
+ if args.export:
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if args.quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ if not os.path.exists(model_file):
+ print(".onnx is not exist, begin to export onnx")
+ from funasr.bin.export_model import ModelExport
+ export_model = ModelExport(
+ cache_dir=args.export_dir,
+ onnx=True,
+ device="cpu",
+ quant=args.quantize,
+ )
+ export_model.export(model_dir)
-model_dir = args.model_name
-if not Path(args.model_name).exists():
- from modelscope.hub.snapshot_download import snapshot_download
- try:
- model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
- except:
- raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
- (model_dir)
-if args.export:
- model_file = os.path.join(model_dir, 'model.onnx')
- if args.quantize:
- model_file = os.path.join(model_dir, 'model_quant.onnx')
- if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
- from funasr.bin.export_model import ModelExport
- export_model = ModelExport(
- cache_dir=args.export_dir,
- onnx=True,
- device="cpu",
- quant=args.quantize,
- )
- export_model.export(model_dir)
\ No newline at end of file
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/funasr/export/README.md b/funasr/export/README.md
deleted file mode 100644
index bbb8bf8..0000000
--- a/funasr/export/README.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# Export models
-
-## Environments
-### Install modelscope and funasr
-
-The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
-```shell
-# pip3 install torch torchaudio
-pip install -U modelscope funasr
-# For the users in China, you could install with the command:
-# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
-```
-### Install the quantization tools
-```shell
-pip install torch-quant # Optional, for torchscript quantization
-pip install onnx onnxruntime # Optional, for onnx quantization
-```
-
-## Usage
- `Tips`: torch>=1.11.0
-
- ```shell
- python -m funasr.export.export_model \
- --model-name [model_name] \
- --export-dir [export_dir] \
- --type [onnx, torch] \
- --quantize [true, false] \
- --fallback-num [fallback_num]
- ```
- `model-name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb).
-
- `export-dir`: the dir where the onnx is export.
-
- `type`: `onnx` or `torch`, export onnx format model or torchscript format model.
-
- `quantize`: `true`, export quantized model at the same time; `false`, export fp32 model only.
-
- `fallback-num`: specify the number of fallback layers to perform automatic mixed precision quantization.
-
-
-### Export onnx format model
-#### Export model from modelscope
-```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
-```
-#### Export model from local path
-The model'name must be `model.pb`
-```shell
-python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
-```
-#### Test onnx model
-Ref to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
-
-### Export torchscripts format model
-#### Export model from modelscope
-```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
-```
-
-#### Export model from local path
-The model'name must be `model.pb`
-```shell
-python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
-```
-#### Test onnx model
-Ref to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
-
-## Runtime
-### ONNXRuntime
-#### ONNXRuntime-python
-Ref to [funasr-onnx](../../runtime/python/onnxruntime/README.md)
-#### ONNXRuntime-cpp
-Ref to [docs](../../runtime/readme.md)
-### Libtorch
-#### Libtorch-python
-Ref to [funasr-torch](../../runtime/python/libtorch/README.md)
-#### Libtorch-cpp
-Undo
-## Performance Benchmark
-
-### Paraformer on CPU
-
-[onnx runtime](../../runtime/docs/benchmark_onnx_cpp.md)
-
-[libtorch runtime](../../runtime/docs/benchmark_libtorch.md)
-
-### Paraformer on GPU
-[nv-triton](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/triton_gpu)
-
-
-## Acknowledge
-Torch model quantization is supported by [BladeDISC](https://github.com/alibaba/BladeDISC), an end-to-end DynamIc Shape Compiler project for machine learning workloads. BladeDISC provides general, transparent, and ease of use performance optimization for TensorFlow/PyTorch workloads on GPGPU and CPU backends. If you are interested, please contact us.
-
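As a companion to the runtime links above, a hedged sketch of running an exported Paraformer ONNX model directly with `onnxruntime`; the input/output names and the 560-dim feature assumption come from the export wrapper removed later in this patch, and paths are placeholders:

```python
# Hedged sketch: run an exported Paraformer ONNX model directly with onnxruntime.
# 'speech'/'speech_lengths' in and 'logits'/'token_num' out follow the export wrapper
# in e2e_asr_paraformer.py below; 560 is its default feats_dim. Paths are placeholders.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("./export/model.onnx")
speech = np.random.randn(1, 100, 560).astype(np.float32)      # (batch, frames, feat_dim)
speech_lengths = np.array([100], dtype=np.int32)
logits, token_num = sess.run(
    ["logits", "token_num"],
    {"speech": speech, "speech_lengths": speech_lengths},
)
print(logits.shape, token_num)
```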
diff --git a/funasr/export/models/CT_Transformer.py b/funasr/export/models/CT_Transformer.py
deleted file mode 100644
index 2319c4a..0000000
--- a/funasr/export/models/CT_Transformer.py
+++ /dev/null
@@ -1,162 +0,0 @@
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.encoder.sanm_encoder import SANMEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.models.encoder.sanm_encoder import SANMVadEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
-
-class CT_Transformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
- https://arxiv.org/pdf/2003.01309.pdf
- """
- def __init__(
- self,
- model,
- max_seq_len=512,
- model_name='punc_model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- self.embed = model.embed
- self.decoder = model.decoder
- # self.model = model
- self.feats_dim = self.embed.embedding_dim
- self.num_embeddings = self.embed.num_embeddings
- self.model_name = model_name
-
- if isinstance(model.encoder, SANMEncoder):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- else:
- assert False, "Only support samn encode."
-
- def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
- """Compute loss value from buffer sequences.
-
- Args:
- input (torch.Tensor): Input ids. (batch, len)
- hidden (torch.Tensor): Target ids. (batch, len)
-
- """
- x = self.embed(inputs)
- # mask = self._target_mask(input)
- h, _ = self.encoder(x, text_lengths)
- y = self.decoder(h)
- return y
-
- def get_dummy_inputs(self):
- length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
- text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
- return (text_indexes, text_lengths)
-
- def get_input_names(self):
- return ['inputs', 'text_lengths']
-
- def get_output_names(self):
- return ['logits']
-
- def get_dynamic_axes(self):
- return {
- 'inputs': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'text_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- }
-
-
-class CT_Transformer_VadRealtime(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
- https://arxiv.org/pdf/2003.01309.pdf
- """
- def __init__(
- self,
- model,
- max_seq_len=512,
- model_name='punc_model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
-
- self.embed = model.embed
- if isinstance(model.encoder, SANMVadEncoder):
- self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
- else:
- assert False, "Only support samn encode."
- self.decoder = model.decoder
- self.model_name = model_name
-
-
-
- def forward(self, inputs: torch.Tensor,
- text_lengths: torch.Tensor,
- vad_indexes: torch.Tensor,
- sub_masks: torch.Tensor,
- ) -> Tuple[torch.Tensor, None]:
- """Compute loss value from buffer sequences.
-
- Args:
- input (torch.Tensor): Input ids. (batch, len)
- hidden (torch.Tensor): Target ids. (batch, len)
-
- """
- x = self.embed(inputs)
- # mask = self._target_mask(input)
- h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
- y = self.decoder(h)
- return y
-
- def with_vad(self):
- return True
-
- def get_dummy_inputs(self):
- length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
- text_lengths = torch.tensor([length], dtype=torch.int32)
- vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
- sub_masks = torch.ones(length, length, dtype=torch.float32)
- sub_masks = torch.tril(sub_masks).type(torch.float32)
- return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
-
- def get_input_names(self):
- return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
-
- def get_output_names(self):
- return ['logits']
-
- def get_dynamic_axes(self):
- return {
- 'inputs': {
- 1: 'feats_length'
- },
- 'vad_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'sub_masks': {
- 2: 'feats_length1',
- 3: 'feats_length2'
- },
- 'logits': {
- 1: 'logits_length'
- },
- }
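These wrappers all follow the same export protocol: `get_dummy_inputs`, `get_input_names`, `get_output_names`, and `get_dynamic_axes` supply everything `torch.onnx.export` needs. A minimal sketch of such a driver, for orientation only (the real driver lives in the export entry point, so the opset and file naming here are assumptions):

```python
# Minimal sketch of the export protocol these wrappers implement; the real driver
# is FunASR's export entry point, so opset and output naming here are assumptions.
import torch

def export_wrapper_to_onnx(wrapper: torch.nn.Module, out_path: str, opset: int = 13):
    wrapper.eval()
    torch.onnx.export(
        wrapper,
        wrapper.get_dummy_inputs(),            # tuple of example tensors
        out_path,
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
        opset_version=opset,
    )
```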
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
deleted file mode 100644
index b7b0889..0000000
--- a/funasr/export/models/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
-
-from funasr.models.e2e_vad import E2EVadModel
-from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
-from funasr.train.abs_model import PunctuationModel
-from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
-from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
-from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
-from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
-from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
-from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-
-
-def get_model(model, export_config=None):
- if isinstance(model, NeatContextualParaformer):
- backbone = ContextualParaformer_backbone_export(model, **export_config)
- embedder = ContextualParaformer_embedder_export(model, **export_config)
- return [embedder, backbone]
- elif isinstance(model, BiCifParaformer):
- return BiCifParaformer_export(model, **export_config)
- elif isinstance(model, ParaformerOnline):
- encoder = ParaformerOnline_encoder_predictor_export(model, model_name="model")
- decoder = ParaformerOnline_decoder_export(model, model_name="decoder")
- return [encoder, decoder]
- elif isinstance(model, Paraformer):
- return Paraformer_export(model, **export_config)
- # elif isinstance(model, Conformer_export):
- # return Conformer_export(model, **export_config)
- elif isinstance(model, E2EVadModel):
- return E2EVadModel_export(model, **export_config)
- elif isinstance(model, PunctuationModel):
- if isinstance(model.punc_model, TargetDelayTransformer):
- return CT_Transformer_export(model.punc_model, **export_config)
- elif isinstance(model.punc_model, VadRealtimeTransformer):
- return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
- else:
- raise "Funasr does not support the given model type currently."
diff --git a/funasr/export/models/decoder/__init__.py b/funasr/export/models/decoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/decoder/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/decoder/contextual_decoder.py b/funasr/export/models/decoder/contextual_decoder.py
deleted file mode 100644
index c6f83d0..0000000
--- a/funasr/export/models/decoder/contextual_decoder.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
-
-
-class ContextualSANMDecoder(nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True,):
- super().__init__()
- # self.embed = model.embed #Embedding(model.embed, max_seq_len)
- self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- for i, d in enumerate(self.model.decoders):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
- d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
- self.model.decoders[i] = DecoderLayerSANM_export(d)
-
- if self.model.decoders2 is not None:
- for i, d in enumerate(self.model.decoders2):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- self.model.decoders2[i] = DecoderLayerSANM_export(d)
-
- for i, d in enumerate(self.model.decoders3):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- self.model.decoders3[i] = DecoderLayerSANM_export(d)
-
- self.output_layer = model.output_layer
- self.after_norm = model.after_norm
- self.model_name = model_name
-
- # bias decoder
- if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
- self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
- self.bias_decoder = self.model.bias_decoder
- # last decoder
- if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
- self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
- if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
- self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
- if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
- self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
- self.last_decoder = self.model.last_decoder
- self.bias_output = self.model.bias_output
- self.dropout = self.model.dropout
-
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- bias_embed: torch.Tensor,
- ):
-
- tgt = ys_in_pad
- tgt_mask = self.make_pad_mask(ys_in_lens)
- tgt_mask, _ = self.prepare_mask(tgt_mask)
- # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = self.make_pad_mask(hlens)
- _, memory_mask = self.prepare_mask(memory_mask)
- # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
- x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
- x, tgt_mask, memory, memory_mask
- )
-
- _, _, x_self_attn, x_src_attn = self.last_decoder(
- x, tgt_mask, memory, memory_mask
- )
-
- # contextual paraformer related
- contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
- # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
- contextual_mask = self.make_pad_mask(contextual_length)
- contextual_mask, _ = self.prepare_mask(contextual_mask)
- # import pdb; pdb.set_trace()
- contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
- cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
-
- if self.bias_output is not None:
- x = torch.cat([x_src_attn, cx], dim=2)
- x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
- x = x_self_attn + self.dropout(x)
-
- if self.model.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- x = self.after_norm(x)
- x = self.output_layer(x)
-
- return x, ys_in_lens
-
-
- def get_dummy_inputs(self, enc_size):
- tgt = torch.LongTensor([0]).unsqueeze(0)
- memory = torch.randn(1, 100, enc_size)
- pre_acoustic_embeds = torch.randn(1, 1, enc_size)
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- cache = [
- torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
- for _ in range(cache_num)
- ]
- return (tgt, memory, pre_acoustic_embeds, cache)
-
- def is_optimizable(self):
- return True
-
- def get_input_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['tgt', 'memory', 'pre_acoustic_embeds'] \
- + ['cache_%d' % i for i in range(cache_num)]
-
- def get_output_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['y'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
- def get_dynamic_axes(self):
- ret = {
- 'tgt': {
- 0: 'tgt_batch',
- 1: 'tgt_length'
- },
- 'memory': {
- 0: 'memory_batch',
- 1: 'memory_length'
- },
- 'pre_acoustic_embeds': {
- 0: 'acoustic_embeds_batch',
- 1: 'acoustic_embeds_length',
- }
- }
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- ret.update({
- 'cache_%d' % d: {
- 0: 'cache_%d_batch' % d,
- 2: 'cache_%d_length' % d
- }
- for d in range(cache_num)
- })
- return ret
-
- def get_model_config(self, path):
- return {
- "dec_type": "XformerDecoder",
- "model_path": os.path.join(path, f'{self.model_name}.onnx'),
- "n_layers": len(self.model.decoders) + len(self.model.decoders2),
- "odim": self.model.decoders[0].size
- }
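`prepare_mask` above (and in the decoder wrappers below) turns a padding mask into the two forms the SANM decoder layers expect. A small self-contained illustration of the 2-D case:

```python
# Illustration of the prepare_mask convention: from a (batch, time) padding mask
# (1 = valid frame) to a (batch, time, 1) multiplicative mask and a
# (batch, 1, 1, time) additive mask (0 for valid positions, -10000 for padding).
import torch

lengths = torch.tensor([2, 4])
mask = (torch.arange(4)[None, :] < lengths[:, None]).float()    # (2, 4)

mask_3d_btd = mask[:, :, None]                                   # multiplies hidden states
mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0           # added to attention scores
print(mask_3d_btd.shape, mask_4d_bhlt.shape)                     # (2, 4, 1) (2, 1, 1, 4)
```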
diff --git a/funasr/export/models/decoder/sanm_decoder.py b/funasr/export/models/decoder/sanm_decoder.py
deleted file mode 100644
index 8f6e553..0000000
--- a/funasr/export/models/decoder/sanm_decoder.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
-
-
-class ParaformerSANMDecoder(nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True,):
- super().__init__()
- # self.embed = model.embed #Embedding(model.embed, max_seq_len)
- self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- for i, d in enumerate(self.model.decoders):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
- d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
- self.model.decoders[i] = DecoderLayerSANM_export(d)
-
- if self.model.decoders2 is not None:
- for i, d in enumerate(self.model.decoders2):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- self.model.decoders2[i] = DecoderLayerSANM_export(d)
-
- for i, d in enumerate(self.model.decoders3):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- self.model.decoders3[i] = DecoderLayerSANM_export(d)
-
- self.output_layer = model.output_layer
- self.after_norm = model.after_norm
- self.model_name = model_name
-
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- ):
-
- tgt = ys_in_pad
- tgt_mask = self.make_pad_mask(ys_in_lens)
- tgt_mask, _ = self.prepare_mask(tgt_mask)
- # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = self.make_pad_mask(hlens)
- _, memory_mask = self.prepare_mask(memory_mask)
- # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
- x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
- x, tgt_mask, memory, memory_mask
- )
- if self.model.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- x = self.after_norm(x)
- x = self.output_layer(x)
-
- return x, ys_in_lens
-
-
- def get_dummy_inputs(self, enc_size):
- tgt = torch.LongTensor([0]).unsqueeze(0)
- memory = torch.randn(1, 100, enc_size)
- pre_acoustic_embeds = torch.randn(1, 1, enc_size)
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- cache = [
- torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
- for _ in range(cache_num)
- ]
- return (tgt, memory, pre_acoustic_embeds, cache)
-
- def is_optimizable(self):
- return True
-
- def get_input_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['tgt', 'memory', 'pre_acoustic_embeds'] \
- + ['cache_%d' % i for i in range(cache_num)]
-
- def get_output_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['y'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
- def get_dynamic_axes(self):
- ret = {
- 'tgt': {
- 0: 'tgt_batch',
- 1: 'tgt_length'
- },
- 'memory': {
- 0: 'memory_batch',
- 1: 'memory_length'
- },
- 'pre_acoustic_embeds': {
- 0: 'acoustic_embeds_batch',
- 1: 'acoustic_embeds_length',
- }
- }
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- ret.update({
- 'cache_%d' % d: {
- 0: 'cache_%d_batch' % d,
- 2: 'cache_%d_length' % d
- }
- for d in range(cache_num)
- })
- return ret
-
- def get_model_config(self, path):
- return {
- "dec_type": "XformerDecoder",
- "model_path": os.path.join(path, f'{self.model_name}.onnx'),
- "n_layers": len(self.model.decoders) + len(self.model.decoders2),
- "odim": self.model.decoders[0].size
- }
-
-
-class ParaformerSANMDecoderOnline(nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True, ):
- super().__init__()
- # self.embed = model.embed #Embedding(model.embed, max_seq_len)
- self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- for i, d in enumerate(self.model.decoders):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
- d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
- self.model.decoders[i] = DecoderLayerSANM_export(d)
-
- if self.model.decoders2 is not None:
- for i, d in enumerate(self.model.decoders2):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- self.model.decoders2[i] = DecoderLayerSANM_export(d)
-
- for i, d in enumerate(self.model.decoders3):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- self.model.decoders3[i] = DecoderLayerSANM_export(d)
-
- self.output_layer = model.output_layer
- self.after_norm = model.after_norm
- self.model_name = model_name
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- *args,
- ):
-
- tgt = ys_in_pad
- tgt_mask = self.make_pad_mask(ys_in_lens)
- tgt_mask, _ = self.prepare_mask(tgt_mask)
- # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = self.make_pad_mask(hlens)
- _, memory_mask = self.prepare_mask(memory_mask)
- # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
- x = tgt
- out_caches = list()
- for i, decoder in enumerate(self.model.decoders):
- in_cache = args[i]
- x, tgt_mask, memory, memory_mask, out_cache = decoder(
- x, tgt_mask, memory, memory_mask, cache=in_cache
- )
- out_caches.append(out_cache)
- if self.model.decoders2 is not None:
- for i, decoder in enumerate(self.model.decoders2):
- in_cache = args[i+len(self.model.decoders)]
- x, tgt_mask, memory, memory_mask, out_cache = decoder(
- x, tgt_mask, memory, memory_mask, cache=in_cache
- )
- out_caches.append(out_cache)
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- x = self.after_norm(x)
- x = self.output_layer(x)
-
- return x, out_caches
-
- def get_dummy_inputs(self, enc_size):
- enc = torch.randn(2, 100, enc_size).type(torch.float32)
- enc_len = torch.tensor([30, 100], dtype=torch.int32)
- acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
- acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
- cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
- cache_num += len(self.model.decoders2)
- cache = [
- torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size-1), dtype=torch.float32)
- for _ in range(cache_num)
- ]
- return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
-
- def get_input_names(self):
- cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
- cache_num += len(self.model.decoders2)
- return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
- + ['in_cache_%d' % i for i in range(cache_num)]
-
- def get_output_names(self):
- cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
- cache_num += len(self.model.decoders2)
- return ['logits', 'sample_ids'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
- def get_dynamic_axes(self):
- ret = {
- 'enc': {
- 0: 'batch_size',
- 1: 'enc_length'
- },
- 'acoustic_embeds': {
- 0: 'batch_size',
- 1: 'token_length'
- },
- 'enc_len': {
- 0: 'batch_size',
- },
- 'acoustic_embeds_len': {
- 0: 'batch_size',
- },
-
- }
- cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
- cache_num += len(self.model.decoders2)
- ret.update({
- 'in_cache_%d' % d: {
- 0: 'batch_size',
- }
- for d in range(cache_num)
- })
- ret.update({
- 'out_cache_%d' % d: {
- 0: 'batch_size',
- }
- for d in range(cache_num)
- })
- return ret
diff --git a/funasr/export/models/decoder/transformer_decoder.py b/funasr/export/models/decoder/transformer_decoder.py
deleted file mode 100644
index 727f5d7..0000000
--- a/funasr/export/models/decoder/transformer_decoder.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-from funasr.export import models
-
-import torch
-import torch.nn as nn
-
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export
-
-
-class ParaformerDecoderSAN(nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True,):
- super().__init__()
- # self.embed = model.embed #Embedding(model.embed, max_seq_len)
- self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- for i, d in enumerate(self.model.decoders):
- if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
- d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
- if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
- d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
- # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
- # d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
- if isinstance(d.src_attn, MultiHeadedAttention):
- d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
- self.model.decoders[i] = DecoderLayer_export(d)
-
- self.output_layer = model.output_layer
- self.after_norm = model.after_norm
- self.model_name = model_name
-
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- ):
-
- tgt = ys_in_pad
- tgt_mask = self.make_pad_mask(ys_in_lens)
- tgt_mask, _ = self.prepare_mask(tgt_mask)
- # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = self.make_pad_mask(hlens)
- _, memory_mask = self.prepare_mask(memory_mask)
- # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
- x = tgt
- x, tgt_mask, memory, memory_mask = self.model.decoders(
- x, tgt_mask, memory, memory_mask
- )
- x = self.after_norm(x)
- x = self.output_layer(x)
-
- return x, ys_in_lens
-
-
- def get_dummy_inputs(self, enc_size):
- tgt = torch.LongTensor([0]).unsqueeze(0)
- memory = torch.randn(1, 100, enc_size)
- pre_acoustic_embeds = torch.randn(1, 1, enc_size)
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- cache = [
- torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
- for _ in range(cache_num)
- ]
- return (tgt, memory, pre_acoustic_embeds, cache)
-
- def is_optimizable(self):
- return True
-
- def get_input_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['tgt', 'memory', 'pre_acoustic_embeds'] \
- + ['cache_%d' % i for i in range(cache_num)]
-
- def get_output_names(self):
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['y'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
- def get_dynamic_axes(self):
- ret = {
- 'tgt': {
- 0: 'tgt_batch',
- 1: 'tgt_length'
- },
- 'memory': {
- 0: 'memory_batch',
- 1: 'memory_length'
- },
- 'pre_acoustic_embeds': {
- 0: 'acoustic_embeds_batch',
- 1: 'acoustic_embeds_length',
- }
- }
- cache_num = len(self.model.decoders) + len(self.model.decoders2)
- ret.update({
- 'cache_%d' % d: {
- 0: 'cache_%d_batch' % d,
- 2: 'cache_%d_length' % d
- }
- for d in range(cache_num)
- })
- return ret
-
- def get_model_config(self, path):
- return {
- "dec_type": "XformerDecoder",
- "model_path": os.path.join(path, f'{self.model_name}.onnx'),
- "n_layers": len(self.model.decoders) + len(self.model.decoders2),
- "odim": self.model.decoders[0].size
- }
\ No newline at end of file
diff --git a/funasr/export/models/decoder/xformer_decoder.py b/funasr/export/models/decoder/xformer_decoder.py
deleted file mode 100644
index 9215c00..0000000
--- a/funasr/export/models/decoder/xformer_decoder.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.attention import MultiHeadedAttention
-
-from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer
-from funasr.export.models.language_models.embed import Embedding
-from funasr.export.models.modules.multihead_att import \
- OnnxMultiHeadedAttention
-
-from funasr.export.utils.torch_function import MakePadMask, subsequent_mask
-
-class XformerDecoder(nn.Module):
- def __init__(self,
- model,
- max_seq_len = 512,
- model_name = 'decoder',
- onnx: bool = True,):
- super().__init__()
- self.embed = Embedding(model.embed, max_seq_len)
- self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = subsequent_mask(max_seq_len, flip=False)
-
- if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
- self.num_heads = self.model.decoders[0].self_attn.h
- self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
-
- # replace multi-head attention module into customized module.
- for i, d in enumerate(self.model.decoders):
- # d is DecoderLayer
- if isinstance(d.self_attn, MultiHeadedAttention):
- d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
- if isinstance(d.src_attn, MultiHeadedAttention):
- d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
- self.model.decoders[i] = OnnxDecoderLayer(d)
-
- self.model_name = model_name
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
-
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(self,
- tgt,
- memory,
- cache):
-
- mask = subsequent_mask(tgt.size(-1)).unsqueeze(0) # (B, T)
-
- x = self.embed(tgt)
- mask = self.prepare_mask(mask)
- new_cache = []
- for c, decoder in zip(cache, self.model.decoders):
- x, mask = decoder(x, mask, memory, None, c)
- new_cache.append(x)
- x = x[:, 1:, :]
-
- if self.model.normalize_before:
- y = self.model.after_norm(x[:, -1])
- else:
- y = x[:, -1]
-
- if self.model.output_layer is not None:
- y = torch.log_softmax(self.model.output_layer(y), dim=-1)
- return y, new_cache
-
- def get_dummy_inputs(self, enc_size):
- tgt = torch.LongTensor([0]).unsqueeze(0)
- memory = torch.randn(1, 100, enc_size)
- cache_num = len(self.model.decoders)
- cache = [
- torch.zeros((1, 1, self.model.decoders[0].size))
- for _ in range(cache_num)
- ]
- return (tgt, memory, cache)
-
- def is_optimizable(self):
- return True
-
- def get_input_names(self):
- cache_num = len(self.model.decoders)
- return ["tgt", "memory"] + [
- "cache_%d" % i for i in range(cache_num)
- ]
-
- def get_output_names(self):
- cache_num = len(self.model.decoders)
- return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
-
- def get_dynamic_axes(self):
- ret = {
- "tgt": {0: "tgt_batch", 1: "tgt_length"},
- "memory": {0: "memory_batch", 1: "memory_length"},
- }
- cache_num = len(self.model.decoders)
- ret.update(
- {
- "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
- for d in range(cache_num)
- }
- )
- return ret
-
- def get_model_config(self, path):
- return {
- "dec_type": "XformerDecoder",
- "model_path": os.path.join(path, f"{self.model_name}.onnx"),
- "n_layers": len(self.model.decoders),
- "odim": self.model.decoders[0].size,
- }
diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
deleted file mode 100644
index 0a3eba6..0000000
--- a/funasr/export/models/e2e_asr_contextual_paraformer.py
+++ /dev/null
@@ -1,174 +0,0 @@
-from audioop import bias
-import logging
-import torch
-import torch.nn as nn
-import numpy as np
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.models.predictor.cif import CifPredictorV2
-from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
-from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
-from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
-from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-
-
-class ContextualParaformer_backbone(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- if isinstance(model.encoder, SANMEncoder):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- elif isinstance(model.encoder, ConformerEncoder):
- self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
- if isinstance(model.predictor, CifPredictorV2):
- self.predictor = CifPredictorV2_export(model.predictor)
-
- # decoder
- if isinstance(model.decoder, ContextualParaformerDecoder):
- self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
- elif isinstance(model.decoder, ParaformerSANMDecoder):
- self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
- elif isinstance(model.decoder, ParaformerDecoderSAN):
- self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-
- self.feats_dim = feats_dim
- self.model_name = model_name
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- bias_embed: torch.Tensor,
- ):
- # a. To device
- batch = {"speech": speech, "speech_lengths": speech_lengths}
- # batch = to_device(batch, device=self.device)
-
- enc, enc_len = self.encoder(**batch)
- mask = self.make_pad_mask(enc_len)[:, None, :]
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
- pre_token_length = pre_token_length.floor().type(torch.int32)
-
- # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])
-
- decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- # sample_ids = decoder_out.argmax(dim=-1)
- return decoder_out, pre_token_length
-
- def get_dummy_inputs(self):
- speech = torch.randn(2, 30, self.feats_dim)
- speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
- bias_embed = torch.randn(2, 1, 512)
- return (speech, speech_lengths, bias_embed)
-
- def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
- import numpy as np
- fbank = np.loadtxt(txt_file)
- fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
- speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
- speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
- return (speech, speech_lengths)
-
- def get_input_names(self):
- return ['speech', 'speech_lengths', 'bias_embed']
-
- def get_output_names(self):
- return ['logits', 'token_num']
-
- def get_dynamic_axes(self):
- return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'bias_embed': {
- 0: 'batch_size',
- 1: 'num_hotwords'
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- }
-
-
-class ContextualParaformer_embedder(nn.Module):
- def __init__(self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,):
- super().__init__()
- self.embedding = model.bias_embed
- model.bias_encoder.batch_first = False
- self.bias_encoder = model.bias_encoder
- # self.bias_encoder.batch_first = False
- self.feats_dim = feats_dim
- self.model_name = "{}_eb".format(model_name)
-
- def forward(self, hotword):
- hotword = self.embedding(hotword).transpose(0, 1) # batch second
- hw_embed, (_, _) = self.bias_encoder(hotword)
- return hw_embed
-
- def get_dummy_inputs(self):
- hotword = torch.tensor([
- [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
- [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
- [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- dtype=torch.int32)
- # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
- return (hotword)
-
- def get_input_names(self):
- return ['hotword']
-
- def get_output_names(self):
- return ['hw_embed']
-
- def get_dynamic_axes(self):
- return {
- 'hotword': {
- 0: 'num_hotwords',
- },
- 'hw_embed': {
- 0: 'num_hotwords',
- },
- }
\ No newline at end of file
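The contextual model is exported as two graphs: an embedder for hotword ids and a backbone that consumes speech plus `bias_embed`. A hedged smoke test of the pair, shaped like their `get_dummy_inputs()`; how `hw_embed` is actually reduced into `bias_embed` at inference time is left to the FunASR runtime and not shown, and the file names are placeholders:

```python
# Hedged smoke test of the two exported ContextualParaformer graphs, using dummy
# tensors shaped like the wrappers' get_dummy_inputs(); paths are placeholders.
import numpy as np
import onnxruntime as ort

embedder = ort.InferenceSession("./export/model_eb.onnx")
backbone = ort.InferenceSession("./export/model.onnx")

hotword = np.array([[100, 101, 0, 0], [1, 0, 0, 0]], dtype=np.int32)    # padded hotword ids
hw_embed = embedder.run(["hw_embed"], {"hotword": hotword})[0]

speech = np.random.randn(2, 30, 560).astype(np.float32)
speech_lengths = np.array([6, 30], dtype=np.int32)
bias_embed = np.random.randn(2, 1, 512).astype(np.float32)              # (batch, num_hotwords, 512)
logits, token_num = backbone.run(
    ["logits", "token_num"],
    {"speech": speech, "speech_lengths": speech_lengths, "bias_embed": bias_embed},
)
```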
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
deleted file mode 100644
index 5697b77..0000000
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ /dev/null
@@ -1,366 +0,0 @@
-import logging
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.models.predictor.cif import CifPredictorV2, CifPredictorV3
-from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
-from funasr.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
-from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
-
-
-class Paraformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- if isinstance(model.encoder, SANMEncoder):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- elif isinstance(model.encoder, ConformerEncoder):
- self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
- if isinstance(model.predictor, CifPredictorV2):
- self.predictor = CifPredictorV2_export(model.predictor)
- if isinstance(model.decoder, ParaformerSANMDecoder):
- self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
- elif isinstance(model.decoder, ParaformerDecoderSAN):
- self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-
- self.feats_dim = feats_dim
- self.model_name = model_name
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ):
- # a. To device
- batch = {"speech": speech, "speech_lengths": speech_lengths}
- # batch = to_device(batch, device=self.device)
-
- enc, enc_len = self.encoder(**batch)
- mask = self.make_pad_mask(enc_len)[:, None, :]
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
- pre_token_length = pre_token_length.floor().type(torch.int32)
-
- decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- # sample_ids = decoder_out.argmax(dim=-1)
-
- return decoder_out, pre_token_length
-
- def get_dummy_inputs(self):
- speech = torch.randn(2, 30, self.feats_dim)
- speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
- return (speech, speech_lengths)
-
- def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
- import numpy as np
- fbank = np.loadtxt(txt_file)
- fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
- speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
- speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
- return (speech, speech_lengths)
-
- def get_input_names(self):
- return ['speech', 'speech_lengths']
-
- def get_output_names(self):
- return ['logits', 'token_num']
-
- def get_dynamic_axes(self):
- return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- }
-
-
-class BiCifParaformer(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- if isinstance(model.encoder, SANMEncoder):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- elif isinstance(model.encoder, ConformerEncoder):
- self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
- else:
- logging.warning("Unsupported encoder type to export.")
- if isinstance(model.predictor, CifPredictorV3):
- self.predictor = CifPredictorV3_export(model.predictor)
- else:
- logging.warning("Wrong predictor type to export.")
- if isinstance(model.decoder, ParaformerSANMDecoder):
- self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
- elif isinstance(model.decoder, ParaformerDecoderSAN):
- self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
- else:
- logging.warning("Unsupported decoder type to export.")
-
- self.feats_dim = feats_dim
- self.model_name = model_name
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ):
- # a. To device
- batch = {"speech": speech, "speech_lengths": speech_lengths}
- # batch = to_device(batch, device=self.device)
-
- enc, enc_len = self.encoder(**batch)
- mask = self.make_pad_mask(enc_len)[:, None, :]
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
- pre_token_length = pre_token_length.round().type(torch.int32)
-
- decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
-
- # get predicted timestamps
- us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
-
- return decoder_out, pre_token_length, us_alphas, us_cif_peak
-
- def get_dummy_inputs(self):
- speech = torch.randn(2, 30, self.feats_dim)
- speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
- return (speech, speech_lengths)
-
- def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
- import numpy as np
- fbank = np.loadtxt(txt_file)
- fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
- speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
- speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
- return (speech, speech_lengths)
-
- def get_input_names(self):
- return ['speech', 'speech_lengths']
-
- def get_output_names(self):
- return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
-
- def get_dynamic_axes(self):
- return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'logits': {
- 0: 'batch_size',
- 1: 'logits_length'
- },
- 'us_alphas': {
- 0: 'batch_size',
- 1: 'alphas_length'
- },
- 'us_cif_peak': {
- 0: 'batch_size',
- 1: 'alphas_length'
- },
- }
-
-
-class ParaformerOnline_encoder_predictor(nn.Module):
- """
- Author: Speech Lab, Alibaba Group, China
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
- self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
- elif isinstance(model.encoder, ConformerEncoder):
- self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
- if isinstance(model.predictor, CifPredictorV2):
- self.predictor = CifPredictorV2_export(model.predictor)
-
- self.feats_dim = feats_dim
- self.model_name = model_name
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ):
- # a. To device
- batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
- # batch = to_device(batch, device=self.device)
-
- enc, enc_len = self.encoder(**batch)
- mask = self.make_pad_mask(enc_len)[:, None, :]
- alphas, _ = self.predictor.forward_cnn(enc, mask)
-
- return enc, enc_len, alphas
-
- def get_dummy_inputs(self):
- speech = torch.randn(2, 30, self.feats_dim)
- speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
- return (speech, speech_lengths)
-
- def get_input_names(self):
- return ['speech', 'speech_lengths']
-
- def get_output_names(self):
- return ['enc', 'enc_len', 'alphas']
-
- def get_dynamic_axes(self):
- return {
- 'speech': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'speech_lengths': {
- 0: 'batch_size',
- },
- 'enc': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- 'enc_len': {
- 0: 'batch_size',
- },
- 'alphas': {
- 0: 'batch_size',
- 1: 'feats_length'
- },
- }
-
-
-class ParaformerOnline_decoder(nn.Module):
- """
- Author: Speech Lab, Alibaba Group, China
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
-
- if isinstance(model.decoder, ParaformerDecoderSAN):
- self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
- elif isinstance(model.decoder, ParaformerSANMDecoder):
- self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
-
- self.feats_dim = feats_dim
- self.model_name = model_name
- self.enc_size = model.encoder._output_size
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def forward(
- self,
- enc: torch.Tensor,
- enc_len: torch.Tensor,
- acoustic_embeds: torch.Tensor,
- acoustic_embeds_len: torch.Tensor,
- *args,
- ):
- decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
- sample_ids = decoder_out.argmax(dim=-1)
-
- return decoder_out, sample_ids, out_caches
-
- def get_dummy_inputs(self, ):
- dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
- return dummy_inputs
-
- def get_input_names(self):
-
- return self.decoder.get_input_names()
-
- def get_output_names(self):
-
- return self.decoder.get_output_names()
-
- def get_dynamic_axes(self):
- return self.decoder.get_dynamic_axes()
diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py
deleted file mode 100644
index d3e8f30..0000000
--- a/funasr/export/models/e2e_vad.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from enum import Enum
-from typing import List, Tuple, Dict, Any
-
-import torch
-from torch import nn
-import math
-
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
-
-class E2EVadModel(nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- feats_dim=400,
- model_name='model',
- **kwargs,):
- super(E2EVadModel, self).__init__()
- self.feats_dim = feats_dim
- self.max_seq_len = max_seq_len
- self.model_name = model_name
- if isinstance(model.encoder, FSMN):
- self.encoder = FSMN_export(model.encoder)
- else:
- raise "unsupported encoder"
-
-
- def forward(self, feats: torch.Tensor, *args, ):
-
- scores, out_caches = self.encoder(feats, *args)
- return scores, out_caches
-
- def get_dummy_inputs(self, frame=30):
- speech = torch.randn(1, frame, self.feats_dim)
- in_cache0 = torch.randn(1, 128, 19, 1)
- in_cache1 = torch.randn(1, 128, 19, 1)
- in_cache2 = torch.randn(1, 128, 19, 1)
- in_cache3 = torch.randn(1, 128, 19, 1)
-
- return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
-
- # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
- # import numpy as np
- # fbank = np.loadtxt(txt_file)
- # fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
- # speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
- # speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
- # return (speech, speech_lengths)
-
- def get_input_names(self):
- return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
-
- def get_output_names(self):
- return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
-
- def get_dynamic_axes(self):
- return {
- 'speech': {
- 1: 'feats_length'
- },
- }
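The four in_cache tensors above are FSMN left-context buffers that must be threaded from one chunk to the next when the VAD runs in streaming mode. A minimal sketch, assuming vad_model is an already-built E2EVadModel and feature_chunks yields (1, T, 400) fbank chunks (both names are placeholders):

import torch

caches = [torch.zeros(1, 128, 19, 1) for _ in range(4)]   # empty left context for the 4 FSMN layers
for chunk in feature_chunks:                               # chunk: (1, T, 400) fbank features
    scores, caches = vad_model(chunk, *caches)             # scores: per-frame speech/non-speech posteriors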
diff --git a/funasr/export/models/encoder/__init__.py b/funasr/export/models/encoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/encoder/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/encoder/conformer_encoder.py b/funasr/export/models/encoder/conformer_encoder.py
deleted file mode 100644
index 76719c5..0000000
--- a/funasr/export/models/encoder/conformer_encoder.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANM
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder
-from funasr.models.transformer.attention import RelPositionMultiHeadedAttention
-# from funasr.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
-from funasr.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export
-
-
-class ConformerEncoder(nn.Module):
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='encoder',
- onnx: bool = True,
- ):
- super().__init__()
- self.embed = model.embed
- self.model = model
- self.feats_dim = feats_dim
- self._output_size = model._output_size
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- for i, d in enumerate(self.model.encoders):
- if isinstance(d.self_attn, MultiHeadedAttentionSANM):
- d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
- if isinstance(d.self_attn, RelPositionMultiHeadedAttention):
- d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn)
- if isinstance(d.feed_forward, PositionwiseFeedForward):
- d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
- self.model.encoders[i] = EncoderLayerConformer_export(d)
-
- self.model_name = model_name
- self.num_heads = model.encoders[0].self_attn.h
- self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-
-
- def prepare_mask(self, mask):
- if len(mask.shape) == 2:
- mask = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask = 1 - mask[:, None, :]
-
- return mask * -10000.0
-
- def forward(self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- ):
- mask = self.make_pad_mask(speech_lengths)
- mask = self.prepare_mask(mask)
- if self.embed is None:
- xs_pad = speech
- else:
- xs_pad = self.embed(speech)
-
- encoder_outs = self.model.encoders(xs_pad, mask)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- if isinstance(xs_pad, tuple):
- xs_pad = xs_pad[0]
- xs_pad = self.model.after_norm(xs_pad)
-
- return xs_pad, speech_lengths
-
- def get_output_size(self):
- return self.model.encoders[0].size
-
- def get_dummy_inputs(self):
- feats = torch.randn(1, 100, self.feats_dim)
- return (feats)
-
- def get_input_names(self):
- return ['feats']
-
- def get_output_names(self):
- return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
- def get_dynamic_axes(self):
- return {
- 'feats': {
- 1: 'feats_length'
- },
- 'encoder_out': {
- 1: 'enc_out_length'
- },
- 'predictor_weight':{
- 1: 'pre_out_length'
- }
-
- }
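A small worked example of prepare_mask above, which turns a 0/1 padding mask into the additive bias that the exported attention modules add to their scores before softmax:

import torch

mask = torch.tensor([[1., 1., 0.]])               # (B, T): last frame is padding
bias = (1 - mask[:, None, None, :]) * -10000.0    # (B, 1, 1, T)
# bias == [[[[-0., -0., -10000.]]]]: padded positions get a large negative
# score, so they receive (near) zero attention weight after softmax.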
diff --git a/funasr/export/models/encoder/fsmn_encoder.py b/funasr/export/models/encoder/fsmn_encoder.py
deleted file mode 100755
index b8e6433..0000000
--- a/funasr/export/models/encoder/fsmn_encoder.py
+++ /dev/null
@@ -1,296 +0,0 @@
-from typing import Tuple, Dict
-import copy
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.encoder.fsmn_encoder import BasicBlock
-
-class LinearTransform(nn.Module):
-
- def __init__(self, input_dim, output_dim):
- super(LinearTransform, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.linear = nn.Linear(input_dim, output_dim, bias=False)
-
- def forward(self, input):
- output = self.linear(input)
-
- return output
-
-
-class AffineTransform(nn.Module):
-
- def __init__(self, input_dim, output_dim):
- super(AffineTransform, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.linear = nn.Linear(input_dim, output_dim)
-
- def forward(self, input):
- output = self.linear(input)
-
- return output
-
-
-class RectifiedLinear(nn.Module):
-
- def __init__(self, input_dim, output_dim):
- super(RectifiedLinear, self).__init__()
- self.dim = input_dim
- self.relu = nn.ReLU()
- self.dropout = nn.Dropout(0.1)
-
- def forward(self, input):
- out = self.relu(input)
- return out
-
-
-class FSMNBlock(nn.Module):
-
- def __init__(
- self,
- input_dim: int,
- output_dim: int,
- lorder=None,
- rorder=None,
- lstride=1,
- rstride=1,
- ):
- super(FSMNBlock, self).__init__()
-
- self.dim = input_dim
-
- if lorder is None:
- return
-
- self.lorder = lorder
- self.rorder = rorder
- self.lstride = lstride
- self.rstride = rstride
-
- self.conv_left = nn.Conv2d(
- self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
-
- if self.rorder > 0:
- self.conv_right = nn.Conv2d(
- self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
- else:
- self.conv_right = None
-
- def forward(self, input: torch.Tensor, cache: torch.Tensor):
- x = torch.unsqueeze(input, 1)
- x_per = x.permute(0, 3, 2, 1) # B D T C
-
- cache = cache.to(x_per.device)
- y_left = torch.cat((cache, x_per), dim=2)
- cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
- y_left = self.conv_left(y_left)
- out = x_per + y_left
-
- if self.conv_right is not None:
- # maybe need to check
- y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
- y_right = y_right[:, :, self.rstride:, :]
- y_right = self.conv_right(y_right)
- out += y_right
-
- out_per = out.permute(0, 3, 2, 1)
- output = out_per.squeeze(1)
-
- return output, cache
-
-
-class BasicBlock_export(nn.Module):
- def __init__(self,
- model,
- ):
- super(BasicBlock_export, self).__init__()
- self.linear = model.linear
- self.fsmn_block = model.fsmn_block
- self.affine = model.affine
- self.relu = model.relu
-
- def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
- x = self.linear(input) # B T D
- # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
- # if cache_layer_name not in in_cache:
- # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
- x, out_cache = self.fsmn_block(x, in_cache)
- x = self.affine(x)
- x = self.relu(x)
- return x, out_cache
-
-
-# class FsmnStack(nn.Sequential):
-# def __init__(self, *args):
-# super(FsmnStack, self).__init__(*args)
-#
-# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
-# x = input
-# for module in self._modules.values():
-# x = module(x, in_cache)
-# return x
-
-
-'''
-FSMN net for keyword spotting
-input_dim: input dimension
-linear_dim: fsmn input dimension
-proj_dim: fsmn projection dimension
-lorder: fsmn left order
-rorder: fsmn right order
-num_syn: output dimension
-fsmn_layers: no. of sequential fsmn layers
-'''
-
-
-class FSMN(nn.Module):
- def __init__(
- self, model,
- ):
- super(FSMN, self).__init__()
-
- # self.input_dim = input_dim
- # self.input_affine_dim = input_affine_dim
- # self.fsmn_layers = fsmn_layers
- # self.linear_dim = linear_dim
- # self.proj_dim = proj_dim
- # self.output_affine_dim = output_affine_dim
- # self.output_dim = output_dim
- #
- # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
- # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
- # self.relu = RectifiedLinear(linear_dim, linear_dim)
- # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
- # range(fsmn_layers)])
- # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
- # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
- # self.softmax = nn.Softmax(dim=-1)
- self.in_linear1 = model.in_linear1
- self.in_linear2 = model.in_linear2
- self.relu = model.relu
- # self.fsmn = model.fsmn
- self.out_linear1 = model.out_linear1
- self.out_linear2 = model.out_linear2
- self.softmax = model.softmax
- self.fsmn = model.fsmn
- for i, d in enumerate(model.fsmn):
- if isinstance(d, BasicBlock):
- self.fsmn[i] = BasicBlock_export(d)
-
- def fuse_modules(self):
- pass
-
- def forward(
- self,
- input: torch.Tensor,
- *args,
- ):
- """
- Args:
- input (torch.Tensor): Input tensor (B, T, D)
-            in_cache: when in_cache is not None, the forward runs in streaming mode. in_cache is a dict, e.g.,
- {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
- """
-
- x = self.in_linear1(input)
- x = self.in_linear2(x)
- x = self.relu(x)
- # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
- out_caches = list()
- for i, d in enumerate(self.fsmn):
- in_cache = args[i]
- x, out_cache = d(x, in_cache)
- out_caches.append(out_cache)
- x = self.out_linear1(x)
- x = self.out_linear2(x)
- x = self.softmax(x)
-
- return x, out_caches
-
-
-'''
-one deep fsmn layer
-dimproj: projection dimension, input and output dimension of memory blocks
-dimlinear: dimension of mapping layer
-lorder: left order
-rorder: right order
-lstride: left stride
-rstride: right stride
-'''
-
-
-class DFSMN(nn.Module):
-
- def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
- super(DFSMN, self).__init__()
-
- self.lorder = lorder
- self.rorder = rorder
- self.lstride = lstride
- self.rstride = rstride
-
- self.expand = AffineTransform(dimproj, dimlinear)
- self.shrink = LinearTransform(dimlinear, dimproj)
-
- self.conv_left = nn.Conv2d(
- dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
-
- if rorder > 0:
- self.conv_right = nn.Conv2d(
- dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
- else:
- self.conv_right = None
-
- def forward(self, input):
- f1 = F.relu(self.expand(input))
- p1 = self.shrink(f1)
-
- x = torch.unsqueeze(p1, 1)
- x_per = x.permute(0, 3, 2, 1)
-
- y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-
- if self.conv_right is not None:
- y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
- y_right = y_right[:, :, self.rstride:, :]
- out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
- else:
- out = x_per + self.conv_left(y_left)
-
- out1 = out.permute(0, 3, 2, 1)
- output = input + out1.squeeze(1)
-
- return output
-
-
-'''
-build stacked dfsmn layers
-'''
-
-
-def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
- repeats = [
- nn.Sequential(
- DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
- for i in range(fsmn_layers)
- ]
-
- return nn.Sequential(*repeats)
-
-
-if __name__ == '__main__':
-    # Smoke test against the original FSMN; the export wrapper defined above only accepts an already-built model.
-    from funasr.models.encoder.fsmn_encoder import FSMN as FSMNTorch
-    fsmn = FSMNTorch(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
- print(fsmn)
-
- num_params = sum(p.numel() for p in fsmn.parameters())
- print('the number of model params: {}'.format(num_params))
- x = torch.zeros(128, 200, 400) # batch-size * time * dim
- y, _ = fsmn(x) # batch-size * time * dim
- print('input shape: {}'.format(x.shape))
- print('output shape: {}'.format(y.shape))
-
- print(fsmn.to_kaldi_net())
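One detail worth noting in FSMNBlock.forward above: the streaming cache keeps only the last (lorder - 1) * lstride frames of left context. A quick check of the arithmetic, assuming the common VAD configuration lorder=20, lstride=1 (an assumption, not read from any config here):

lorder, lstride = 20, 1               # assumed FSMN left order and stride
cache_frames = (lorder - 1) * lstride
assert cache_frames == 19             # matches the (1, 128, 19, 1) caches in E2EVadModel.get_dummy_inputs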
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
deleted file mode 100644
index 7ef863e..0000000
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANM
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
-from funasr.models.transformer.embedding import StreamSinusoidalPositionEncoder
-
-
-class SANMEncoder(nn.Module):
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='encoder',
- onnx: bool = True,
- ):
- super().__init__()
- self.embed = model.embed
- if isinstance(self.embed, StreamSinusoidalPositionEncoder):
- self.embed = None
- self.model = model
- self.feats_dim = feats_dim
- self._output_size = model._output_size
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- if hasattr(model, 'encoders0'):
- for i, d in enumerate(self.model.encoders0):
- if isinstance(d.self_attn, MultiHeadedAttentionSANM):
- d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
- if isinstance(d.feed_forward, PositionwiseFeedForward):
- d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
- self.model.encoders0[i] = EncoderLayerSANM_export(d)
-
- for i, d in enumerate(self.model.encoders):
- if isinstance(d.self_attn, MultiHeadedAttentionSANM):
- d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
- if isinstance(d.feed_forward, PositionwiseFeedForward):
- d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
- self.model.encoders[i] = EncoderLayerSANM_export(d)
-
- self.model_name = model_name
- self.num_heads = model.encoders[0].self_attn.h
- self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-
-
- def prepare_mask(self, mask):
- mask_3d_btd = mask[:, :, None]
- if len(mask.shape) == 2:
- mask_4d_bhlt = 1 - mask[:, None, None, :]
- elif len(mask.shape) == 3:
- mask_4d_bhlt = 1 - mask[:, None, :]
- mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- online: bool = False
- ):
- if not online:
- speech = speech * self._output_size ** 0.5
- mask = self.make_pad_mask(speech_lengths)
- mask = self.prepare_mask(mask)
- if self.embed is None:
- xs_pad = speech
- else:
- xs_pad = self.embed(speech)
-
- encoder_outs = self.model.encoders0(xs_pad, mask)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- encoder_outs = self.model.encoders(xs_pad, mask)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- xs_pad = self.model.after_norm(xs_pad)
-
- return xs_pad, speech_lengths
-
- def get_output_size(self):
- return self.model.encoders[0].size
-
- def get_dummy_inputs(self):
- feats = torch.randn(1, 100, self.feats_dim)
- return (feats)
-
- def get_input_names(self):
- return ['feats']
-
- def get_output_names(self):
- return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
- def get_dynamic_axes(self):
- return {
- 'feats': {
- 1: 'feats_length'
- },
- 'encoder_out': {
- 1: 'enc_out_length'
- },
- 'predictor_weight':{
- 1: 'pre_out_length'
- }
-
- }
-
-
-class SANMVadEncoder(nn.Module):
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='encoder',
- onnx: bool = True,
- ):
- super().__init__()
- self.embed = model.embed
- self.model = model
- self.feats_dim = feats_dim
- self._output_size = model._output_size
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- if hasattr(model, 'encoders0'):
- for i, d in enumerate(self.model.encoders0):
- if isinstance(d.self_attn, MultiHeadedAttentionSANM):
- d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
- if isinstance(d.feed_forward, PositionwiseFeedForward):
- d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
- self.model.encoders0[i] = EncoderLayerSANM_export(d)
-
- for i, d in enumerate(self.model.encoders):
- if isinstance(d.self_attn, MultiHeadedAttentionSANM):
- d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
- if isinstance(d.feed_forward, PositionwiseFeedForward):
- d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
- self.model.encoders[i] = EncoderLayerSANM_export(d)
-
- self.model_name = model_name
- self.num_heads = model.encoders[0].self_attn.h
- self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-
- def prepare_mask(self, mask, sub_masks):
- mask_3d_btd = mask[:, :, None]
- mask_4d_bhlt = (1 - sub_masks) * -10000.0
-
- return mask_3d_btd, mask_4d_bhlt
-
- def forward(self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- vad_masks: torch.Tensor,
- sub_masks: torch.Tensor,
- ):
- speech = speech * self._output_size ** 0.5
- mask = self.make_pad_mask(speech_lengths)
- vad_masks = self.prepare_mask(mask, vad_masks)
- mask = self.prepare_mask(mask, sub_masks)
-
- if self.embed is None:
- xs_pad = speech
- else:
- xs_pad = self.embed(speech)
-
- encoder_outs = self.model.encoders0(xs_pad, mask)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- # encoder_outs = self.model.encoders(xs_pad, mask)
- for layer_idx, encoder_layer in enumerate(self.model.encoders):
- if layer_idx == len(self.model.encoders) - 1:
- mask = vad_masks
- encoder_outs = encoder_layer(xs_pad, mask)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- xs_pad = self.model.after_norm(xs_pad)
-
- return xs_pad, speech_lengths
-
- def get_output_size(self):
- return self.model.encoders[0].size
-
- # def get_dummy_inputs(self):
- # feats = torch.randn(1, 100, self.feats_dim)
- # return (feats)
- #
- # def get_input_names(self):
- # return ['feats']
- #
- # def get_output_names(self):
- # return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
- #
- # def get_dynamic_axes(self):
- # return {
- # 'feats': {
- # 1: 'feats_length'
- # },
- # 'encoder_out': {
- # 1: 'enc_out_length'
- # },
- # 'predictor_weight': {
- # 1: 'pre_out_length'
- # }
- #
- # }
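SANMEncoder.prepare_mask above derives two views from a single padding mask: a multiplicative (B, T, 1) mask for the FSMN memory branch and an additive (B, 1, 1, T) bias for self-attention. A minimal illustration:

import torch

mask = torch.tensor([[1., 1., 0.]])                       # (B, T) padding mask
mask_3d_btd = mask[:, :, None]                            # (1, 3, 1), multiplies the FSMN memory
mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0    # (1, 1, 1, 3), added to the attention scores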
diff --git a/funasr/export/models/modules/__init__.py b/funasr/export/models/modules/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/modules/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/modules/decoder_layer.py b/funasr/export/models/modules/decoder_layer.py
deleted file mode 100644
index 9a464a4..0000000
--- a/funasr/export/models/modules/decoder_layer.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-class DecoderLayerSANM(nn.Module):
-
- def __init__(
- self,
- model
- ):
- super().__init__()
- self.self_attn = model.self_attn
- self.src_attn = model.src_attn
- self.feed_forward = model.feed_forward
- self.norm1 = model.norm1
- self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
- self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
- self.size = model.size
-
-
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-
- residual = tgt
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
-
- x = tgt
- if self.self_attn is not None:
- tgt = self.norm2(tgt)
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
- x = residual + x
-
- if self.src_attn is not None:
- residual = x
- x = self.norm3(x)
- x = residual + self.src_attn(x, memory, memory_mask)
-
-
- return x, tgt_mask, memory, memory_mask, cache
-
-
-class DecoderLayer(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.self_attn = model.self_attn
- self.src_attn = model.src_attn
- self.feed_forward = model.feed_forward
- self.norm1 = model.norm1
- self.norm2 = model.norm2
- self.norm3 = model.norm3
-
- def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
- residual = tgt
- tgt = self.norm1(tgt)
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
-
- residual = x
- x = self.norm2(x)
-
- x = residual + self.src_attn(x, memory, memory, memory_mask)
-
- residual = x
- x = self.norm3(x)
- x = residual + self.feed_forward(x)
-
- return x, tgt_mask, memory, memory_mask
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
deleted file mode 100644
index 7d01397..0000000
--- a/funasr/export/models/modules/encoder_layer.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-class EncoderLayerSANM(nn.Module):
- def __init__(
- self,
- model,
- ):
- """Construct an EncoderLayer object."""
- super().__init__()
- self.self_attn = model.self_attn
- self.feed_forward = model.feed_forward
- self.norm1 = model.norm1
- self.norm2 = model.norm2
- self.in_size = model.in_size
- self.size = model.size
-
- def forward(self, x, mask):
-
- residual = x
- x = self.norm1(x)
- x = self.self_attn(x, mask)
- if self.in_size == self.size:
- x = x + residual
- residual = x
- x = self.norm2(x)
- x = self.feed_forward(x)
- x = x + residual
-
- return x, mask
-
-
-class EncoderLayerConformer(nn.Module):
- def __init__(
- self,
- model,
- ):
- """Construct an EncoderLayer object."""
- super().__init__()
- self.self_attn = model.self_attn
- self.feed_forward = model.feed_forward
- self.feed_forward_macaron = model.feed_forward_macaron
- self.conv_module = model.conv_module
- self.norm_ff = model.norm_ff
- self.norm_mha = model.norm_mha
- self.norm_ff_macaron = model.norm_ff_macaron
- self.norm_conv = model.norm_conv
- self.norm_final = model.norm_final
- self.size = model.size
-
- def forward(self, x, mask):
- if isinstance(x, tuple):
- x, pos_emb = x[0], x[1]
- else:
- x, pos_emb = x, None
-
- if self.feed_forward_macaron is not None:
- residual = x
- x = self.norm_ff_macaron(x)
- x = residual + self.feed_forward_macaron(x) * 0.5
-
- residual = x
- x = self.norm_mha(x)
-
- x_q = x
-
- if pos_emb is not None:
- x_att = self.self_attn(x_q, x, x, pos_emb, mask)
- else:
- x_att = self.self_attn(x_q, x, x, mask)
- x = residual + x_att
-
- if self.conv_module is not None:
- residual = x
- x = self.norm_conv(x)
- x = residual + self.conv_module(x)
-
- residual = x
- x = self.norm_ff(x)
- x = residual + self.feed_forward(x) * 0.5
-
- x = self.norm_final(x)
-
- if pos_emb is not None:
- return (x, pos_emb), mask
-
- return x, mask
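For reference, EncoderLayerConformer.forward above keeps the standard macaron ordering of the Conformer design, where two half-step feed-forward branches replace a single full one: x = x + 0.5 * FFN_macaron(LN(x)), then x = x + MHSA(LN(x)), then x = x + Conv(LN(x)), then x = x + 0.5 * FFN(LN(x)), followed by the final LayerNorm.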
diff --git a/funasr/export/models/modules/feedforward.py b/funasr/export/models/modules/feedforward.py
deleted file mode 100644
index 9388ae1..0000000
--- a/funasr/export/models/modules/feedforward.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-import torch
-import torch.nn as nn
-
-
-class PositionwiseFeedForward(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.w_1 = model.w_1
- self.w_2 = model.w_2
- self.activation = model.activation
-
- def forward(self, x):
- x = self.activation(self.w_1(x))
- x = self.w_2(x)
- return x
-
-
-class PositionwiseFeedForwardDecoderSANM(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.w_1 = model.w_1
- self.w_2 = model.w_2
- self.activation = model.activation
- self.norm = model.norm
-
- def forward(self, x):
- x = self.activation(self.w_1(x))
- x = self.w_2(self.norm(x))
- return x
\ No newline at end of file
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
deleted file mode 100644
index 4885c4e..0000000
--- a/funasr/export/models/modules/multihead_att.py
+++ /dev/null
@@ -1,243 +0,0 @@
-import os
-import math
-
-import torch
-import torch.nn as nn
-
-
-class MultiHeadedAttentionSANM(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.d_k = model.d_k
- self.h = model.h
- self.linear_out = model.linear_out
- self.linear_q_k_v = model.linear_q_k_v
- self.fsmn_block = model.fsmn_block
- self.pad_fn = model.pad_fn
-
- self.attn = None
- self.all_head_size = self.h * self.d_k
-
- def forward(self, x, mask):
- mask_3d_btd, mask_4d_bhlt = mask
- q_h, k_h, v_h, v = self.forward_qkv(x)
- fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
- q_h = q_h * self.d_k**(-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
- return att_outs + fsmn_memory
-
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
- new_x_shape = x.size()[:-1] + (self.h, self.d_k)
- x = x.view(new_x_shape)
- return x.permute(0, 2, 1, 3)
-
- def forward_qkv(self, x):
- q_k_v = self.linear_q_k_v(x)
- q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
- q_h = self.transpose_for_scores(q)
- k_h = self.transpose_for_scores(k)
- v_h = self.transpose_for_scores(v)
- return q_h, k_h, v_h, v
-
- def forward_fsmn(self, inputs, mask):
- # b, t, d = inputs.size()
- # mask = torch.reshape(mask, (b, -1, 1))
- inputs = inputs * mask
- x = inputs.transpose(1, 2)
- x = self.pad_fn(x)
- x = self.fsmn_block(x)
- x = x.transpose(1, 2)
- x = x + inputs
- x = x * mask
- return x
-
- def forward_attention(self, value, scores, mask):
- scores = scores + mask
-
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
-
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- return self.linear_out(context_layer) # (batch, time1, d_model)
-
-
-def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
- x = x * mask
- x = x.transpose(1, 2)
- if cache is None:
- x = pad_fn(x)
- else:
- x = torch.cat((cache, x), dim=2)
- cache = x[:, :, -(kernel_size-1):]
- return x, cache
-
-
-torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
-if torch_version >= (1, 8):
- import torch.fx
- torch.fx.wrap('preprocess_for_attn')
-
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.fsmn_block = model.fsmn_block
- self.pad_fn = model.pad_fn
- self.kernel_size = model.kernel_size
- self.attn = None
-
- def forward(self, inputs, mask, cache=None):
- x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
- x = self.fsmn_block(x)
- x = x.transpose(1, 2)
-
- x = x + inputs
- x = x * mask
- return x, cache
-
-
-class MultiHeadedAttentionCrossAtt(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.d_k = model.d_k
- self.h = model.h
- self.linear_q = model.linear_q
- self.linear_k_v = model.linear_k_v
- self.linear_out = model.linear_out
- self.attn = None
- self.all_head_size = self.h * self.d_k
-
- def forward(self, x, memory, memory_mask):
- q, k, v = self.forward_qkv(x, memory)
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
- return self.forward_attention(v, scores, memory_mask)
-
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
- new_x_shape = x.size()[:-1] + (self.h, self.d_k)
- x = x.view(new_x_shape)
- return x.permute(0, 2, 1, 3)
-
- def forward_qkv(self, x, memory):
- q = self.linear_q(x)
-
- k_v = self.linear_k_v(memory)
- k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
- q = self.transpose_for_scores(q)
- k = self.transpose_for_scores(k)
- v = self.transpose_for_scores(v)
- return q, k, v
-
- def forward_attention(self, value, scores, mask):
- scores = scores + mask
-
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
-
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- return self.linear_out(context_layer) # (batch, time1, d_model)
-
-
-class OnnxMultiHeadedAttention(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.d_k = model.d_k
- self.h = model.h
- self.linear_q = model.linear_q
- self.linear_k = model.linear_k
- self.linear_v = model.linear_v
- self.linear_out = model.linear_out
- self.attn = None
- self.all_head_size = self.h * self.d_k
-
- def forward(self, query, key, value, mask):
- q, k, v = self.forward_qkv(query, key, value)
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
- return self.forward_attention(v, scores, mask)
-
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
- new_x_shape = x.size()[:-1] + (self.h, self.d_k)
- x = x.view(new_x_shape)
- return x.permute(0, 2, 1, 3)
-
- def forward_qkv(self, query, key, value):
- q = self.linear_q(query)
- k = self.linear_k(key)
- v = self.linear_v(value)
- q = self.transpose_for_scores(q)
- k = self.transpose_for_scores(k)
- v = self.transpose_for_scores(v)
- return q, k, v
-
- def forward_attention(self, value, scores, mask):
- scores = scores + mask
-
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
-
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- return self.linear_out(context_layer) # (batch, time1, d_model)
-
-
-class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
- def __init__(self, model):
- super().__init__(model)
- self.linear_pos = model.linear_pos
- self.pos_bias_u = model.pos_bias_u
- self.pos_bias_v = model.pos_bias_v
-
- def forward(self, query, key, value, pos_emb, mask):
- q, k, v = self.forward_qkv(query, key, value)
- q = q.transpose(1, 2) # (batch, time1, head, d_k)
-
- p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
-
- # (batch, head, time1, d_k)
- q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
- # (batch, head, time1, d_k)
- q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
- # compute attention score
- # first compute matrix a and matrix c
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
- # (batch, head, time1, time2)
- matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
- # compute matrix b and matrix d
- # (batch, head, time1, time1)
- matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
- matrix_bd = self.rel_shift(matrix_bd)
-
- scores = (matrix_ac + matrix_bd) / math.sqrt(
- self.d_k
- ) # (batch, head, time1, time2)
-
- return self.forward_attention(v, scores, mask)
-
- def rel_shift(self, x):
- zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
- x_padded = torch.cat([zero_pad, x], dim=-1)
-
- x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
- x = x_padded[:, :, 1:].view_as(x)[
- :, :, :, : x.size(-1) // 2 + 1
- ] # only keep the positions from 0 to time2
- return x
-
- def forward_attention(self, value, scores, mask):
- scores = scores + mask
-
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
-
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- return self.linear_out(context_layer) # (batch, time1, d_model)
-
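To make the rel_shift step in OnnxRelPosMultiHeadedAttention above easier to follow, here is a shape-only sanity check; the time2 = 2*time1 - 1 layout of the relative-position scores is an assumption about how pos_emb is produced upstream:

import torch

B, H, T1 = 1, 2, 4
x = torch.randn(B, H, T1, 2 * T1 - 1)                 # relative-position score matrix
zero_pad = torch.zeros((*x.size()[:3], 1))
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
shifted = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
print(shifted.shape)  # torch.Size([1, 2, 4, 4]): per-query scores over the T1 keys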
diff --git a/funasr/export/models/predictor/__init__.py b/funasr/export/models/predictor/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/predictor/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
deleted file mode 100644
index 03c4433..0000000
--- a/funasr/export/models/predictor/cif.py
+++ /dev/null
@@ -1,295 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
- if maxlen is None:
- maxlen = lengths.max()
- row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
- matrix = torch.unsqueeze(lengths, dim=-1)
- mask = row_vector < matrix
- mask = mask.detach()
-
- return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
-
-def sequence_mask_scripts(lengths, maxlen:int):
- row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device)
- matrix = torch.unsqueeze(lengths, dim=-1)
- mask = row_vector < matrix
- return mask.type(torch.float32).to(lengths.device)
-
-class CifPredictorV2(nn.Module):
- def __init__(self, model):
- super().__init__()
-
- self.pad = model.pad
- self.cif_conv1d = model.cif_conv1d
- self.cif_output = model.cif_output
- self.threshold = model.threshold
- self.smooth_factor = model.smooth_factor
- self.noise_threshold = model.noise_threshold
- self.tail_threshold = model.tail_threshold
-
- def forward(self, hidden: torch.Tensor,
- mask: torch.Tensor,
- ):
- alphas, token_num = self.forward_cnn(hidden, mask)
- mask = mask.transpose(-1, -2).float()
- mask = mask.squeeze(-1)
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- return acoustic_embeds, token_num, alphas, cif_peak
-
- def forward_cnn(self, hidden: torch.Tensor,
- mask: torch.Tensor,
- ):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- alphas = alphas.squeeze(-1)
- token_num = alphas.sum(-1)
-
- return alphas, token_num
-
- def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
- b, t, d = hidden.size()
- tail_threshold = self.tail_threshold
-
- zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
- ones_t = torch.ones_like(zeros_t)
-
- mask_1 = torch.cat([mask, zeros_t], dim=1)
- mask_2 = torch.cat([ones_t, mask], dim=1)
- mask = mask_2 - mask_1
- tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, zeros_t], dim=1)
- alphas = torch.add(alphas, tail_threshold)
-
- zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
- hidden = torch.cat([hidden, zeros], dim=1)
- token_num = alphas.sum(dim=-1)
- token_num_floor = torch.floor(token_num)
-
- return hidden, alphas, token_num_floor
-
-
-# @torch.jit.script
-# def cif(hidden, alphas, threshold: float):
-# batch_size, len_time, hidden_size = hidden.size()
-# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-#
-# # loop varss
-# integrate = torch.zeros([batch_size], device=hidden.device)
-# frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
-# # intermediate vars along time
-# list_fires = []
-# list_frames = []
-#
-# for t in range(len_time):
-# alpha = alphas[:, t]
-# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
-#
-# integrate += alpha
-# list_fires.append(integrate)
-#
-# fire_place = integrate >= threshold
-# integrate = torch.where(fire_place,
-# integrate - torch.ones([batch_size], device=hidden.device),
-# integrate)
-# cur = torch.where(fire_place,
-# distribution_completion,
-# alpha)
-# remainds = alpha - cur
-#
-# frame += cur[:, None] * hidden[:, t, :]
-# list_frames.append(frame)
-# frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-# remainds[:, None] * hidden[:, t, :],
-# frame)
-#
-# fires = torch.stack(list_fires, 1)
-# frames = torch.stack(list_frames, 1)
-# list_ls = []
-# len_labels = torch.floor(alphas.sum(-1)).int()
-# max_label_len = len_labels.max()
-# for b in range(batch_size):
-# fire = fires[b, :]
-# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-# pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
-# list_ls.append(torch.cat([l, pad_l], 0))
-# return torch.stack(list_ls, 0), fires
-
-
-@torch.jit.script
-def cif(hidden, alphas, threshold: float):
- batch_size, len_time, hidden_size = hidden.size()
- threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-
-    # loop vars
- integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
- frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
- # intermediate vars along time
- list_fires = []
- list_frames = []
-
- for t in range(len_time):
- alpha = alphas[:, t]
- distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
-
- integrate += alpha
- list_fires.append(integrate)
-
- fire_place = integrate >= threshold
- integrate = torch.where(fire_place,
- integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
- integrate)
- cur = torch.where(fire_place,
- distribution_completion,
- alpha)
- remainds = alpha - cur
-
- frame += cur[:, None] * hidden[:, t, :]
- list_frames.append(frame)
- frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
- remainds[:, None] * hidden[:, t, :],
- frame)
-
- fires = torch.stack(list_fires, 1)
- frames = torch.stack(list_frames, 1)
-
- fire_idxs = fires >= threshold
- frame_fires = torch.zeros_like(hidden)
- max_label_len = frames[0, fire_idxs[0]].size(0)
- for b in range(batch_size):
- frame_fire = frames[b, fire_idxs[b]]
- frame_len = frame_fire.size(0)
- frame_fires[b, :frame_len, :] = frame_fire
-
- if frame_len >= max_label_len:
- max_label_len = frame_len
- frame_fires = frame_fires[:, :max_label_len, :]
- return frame_fires, fires
-
-
-class CifPredictorV3(nn.Module):
- def __init__(self, model):
- super().__init__()
-
- self.pad = model.pad
- self.cif_conv1d = model.cif_conv1d
- self.cif_output = model.cif_output
- self.threshold = model.threshold
- self.smooth_factor = model.smooth_factor
- self.noise_threshold = model.noise_threshold
- self.tail_threshold = model.tail_threshold
-
- self.upsample_times = model.upsample_times
- self.upsample_cnn = model.upsample_cnn
- self.blstm = model.blstm
- self.cif_output2 = model.cif_output2
- self.smooth_factor2 = model.smooth_factor2
- self.noise_threshold2 = model.noise_threshold2
-
- def forward(self, hidden: torch.Tensor,
- mask: torch.Tensor,
- ):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- alphas = alphas.squeeze(-1)
- token_num = alphas.sum(-1)
-
- mask = mask.squeeze(-1)
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- return acoustic_embeds, token_num, alphas, cif_peak
-
- def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
- h = hidden
- b = hidden.shape[0]
- context = h.transpose(1, 2)
-
- # generate alphas2
- _output = context
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1, 2)
- output2, (_, _) = self.blstm(output2)
- alphas2 = torch.sigmoid(self.cif_output2(output2))
- alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
-
- mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
- mask = mask.unsqueeze(-1)
- alphas2 = alphas2 * mask
- alphas2 = alphas2.squeeze(-1)
- _token_num = alphas2.sum(-1)
- alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
- # upsampled alphas and cif_peak
- us_alphas = alphas2
- us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
- return us_alphas, us_cif_peak
-
- def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
- b, t, d = hidden.size()
- tail_threshold = self.tail_threshold
-
- zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
- ones_t = torch.ones_like(zeros_t)
-
- mask_1 = torch.cat([mask, zeros_t], dim=1)
- mask_2 = torch.cat([ones_t, mask], dim=1)
- mask = mask_2 - mask_1
- tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, zeros_t], dim=1)
- alphas = torch.add(alphas, tail_threshold)
-
- zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
- hidden = torch.cat([hidden, zeros], dim=1)
- token_num = alphas.sum(dim=-1)
- token_num_floor = torch.floor(token_num)
-
- return hidden, alphas, token_num_floor
-
-
-@torch.jit.script
-def cif_wo_hidden(alphas, threshold: float):
- batch_size, len_time = alphas.size()
-
-    # loop vars
- integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
- # intermediate vars along time
- list_fires = []
-
- for t in range(len_time):
- alpha = alphas[:, t]
-
- integrate += alpha
- list_fires.append(integrate)
-
- fire_place = integrate >= threshold
- integrate = torch.where(fire_place,
- integrate - torch.ones([batch_size], device=alphas.device)*threshold,
- integrate)
-
- fires = torch.stack(list_fires, 1)
- return fires
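A tiny worked example of the scripted cif above (continuous integrate-and-fire): with a firing threshold of 1.0, the per-frame weights below integrate to 3.0, so three acoustic embeddings fire. This assumes the removed module is importable as-is:

import torch

hidden = torch.ones(1, 4, 2)                     # (B, T, D) encoder frames
alphas = torch.tensor([[0.6, 0.6, 0.9, 0.9]])    # per-frame weights, sum = 3.0
frames, fires = cif(hidden, alphas, 1.0)
print(frames.shape)           # torch.Size([1, 3, 2]): three fired embeddings
print((fires >= 1.0).sum())   # tensor(3): fires at t = 1, 2 and 3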
diff --git a/funasr/export/test/__init__.py b/funasr/export/test/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/test/__init__.py
+++ /dev/null
diff --git a/funasr/export/test/test_onnx.py b/funasr/export/test/test_onnx.py
deleted file mode 100644
index 4351728..0000000
--- a/funasr/export/test/test_onnx.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
- onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"
- sess = onnxruntime.InferenceSession(onnx_path)
- input_name = [nd.name for nd in sess.get_inputs()]
- output_name = [nd.name for nd in sess.get_outputs()]
-
- def _get_feed_dict(feats_length):
- return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': np.array([feats_length,], dtype=np.int32)}
-
- def _run(feed_dict):
- output = sess.run(output_name, input_feed=feed_dict)
- for name, value in zip(output_name, output):
- print('{}: {}'.format(name, value.shape))
-
- _run(_get_feed_dict(100))
- _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test/test_onnx_punc.py b/funasr/export/test/test_onnx_punc.py
deleted file mode 100644
index 39f85f4..0000000
--- a/funasr/export/test/test_onnx_punc.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
- onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
- sess = onnxruntime.InferenceSession(onnx_path)
- input_name = [nd.name for nd in sess.get_inputs()]
- output_name = [nd.name for nd in sess.get_outputs()]
-
- def _get_feed_dict(text_length):
- return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
-
- def _run(feed_dict):
- output = sess.run(output_name, input_feed=feed_dict)
- for name, value in zip(output_name, output):
- print('{}: {}'.format(name, value))
- _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py
deleted file mode 100644
index 507226e..0000000
--- a/funasr/export/test/test_onnx_punc_vadrealtime.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
- onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
- sess = onnxruntime.InferenceSession(onnx_path)
- input_name = [nd.name for nd in sess.get_inputs()]
- output_name = [nd.name for nd in sess.get_outputs()]
-
- def _get_feed_dict(text_length):
- return {'inputs': np.ones((1, text_length), dtype=np.int64),
- 'text_lengths': np.array([text_length,], dtype=np.int32),
- 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
- 'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
- }
-
- def _run(feed_dict):
- output = sess.run(output_name, input_feed=feed_dict)
- for name, value in zip(output_name, output):
- print('{}: {}'.format(name, value))
- _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_vad.py b/funasr/export/test/test_onnx_vad.py
deleted file mode 100644
index 12f058f..0000000
--- a/funasr/export/test/test_onnx_vad.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
- onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
- sess = onnxruntime.InferenceSession(onnx_path)
- input_name = [nd.name for nd in sess.get_inputs()]
- output_name = [nd.name for nd in sess.get_outputs()]
-
- def _get_feed_dict(feats_length):
-
- return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
- 'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
- 'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
- 'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
- 'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
- }
-
- def _run(feed_dict):
- output = sess.run(output_name, input_feed=feed_dict)
- for name, value in zip(output_name, output):
- print('{}: {}'.format(name, value.shape))
-
- _run(_get_feed_dict(100))
- _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test/test_torchscripts.py b/funasr/export/test/test_torchscripts.py
deleted file mode 100644
index 9afec74..0000000
--- a/funasr/export/test/test_torchscripts.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torch
-import numpy as np
-
-if __name__ == '__main__':
- onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
- loaded = torch.jit.load(onnx_path)
-
- x = torch.rand([2, 21, 560])
- x_len = torch.IntTensor([6, 21])
- res = loaded(x, x_len)
- print(res[0].size(), res[1])
-
- x = torch.rand([5, 50, 560])
- x_len = torch.IntTensor([6, 21, 10, 30, 50])
- res = loaded(x, x_len)
- print(res[0].size(), res[1])
-
\ No newline at end of file
diff --git a/funasr/export/utils/__init__.py b/funasr/export/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/utils/__init__.py
+++ /dev/null
diff --git a/funasr/export/utils/torch_function.py b/funasr/export/utils/torch_function.py
deleted file mode 100644
index a078a7e..0000000
--- a/funasr/export/utils/torch_function.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-import numpy as np
-
-
-class MakePadMask(nn.Module):
- def __init__(self, max_seq_len=512, flip=True):
- super().__init__()
- if flip:
- self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
- else:
- self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
-
- def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
- """Make mask tensor containing indices of padded part.
- This implementation creates the same mask tensor with original make_pad_mask,
- which can be converted into onnx format.
- Dimension length of xs should be 2 or 3.
- """
- if length_dim == 0:
- raise ValueError("length_dim cannot be 0: {}".format(length_dim))
-
- if xs is not None and len(xs.shape) == 3:
- if length_dim == 1:
- lengths = lengths.unsqueeze(1).expand(
- *xs.transpose(1, 2).shape[:2])
- else:
- lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
-
- if maxlen is not None:
- m = maxlen
- elif xs is not None:
- m = xs.shape[-1]
- else:
- m = torch.max(lengths)
-
- mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
-
- if length_dim == 1:
- return mask.transpose(1, 2)
- else:
- return mask
-
-class sequence_mask(nn.Module):
- def __init__(self, max_seq_len=512, flip=True):
- super().__init__()
-
- def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
- if max_seq_len is None:
- max_seq_len = lengths.max()
- row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
- matrix = torch.unsqueeze(lengths, dim=-1)
- mask = row_vector < matrix
-
- return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
-
-def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
- if out is None:
- denom = input.norm(p, dim, keepdim=True).expand_as(input)
- return input / denom
- else:
- denom = input.norm(p, dim, keepdim=True).expand_as(input)
- return torch.div(input, denom, out=out)
-
-def subsequent_mask(size: torch.Tensor):
- return torch.ones(size, size).tril()
-
-
-def MakePadMask_test():
- feats_length = torch.tensor([10]).type(torch.long)
- mask_fn = MakePadMask()
- mask = mask_fn(feats_length)
- print(mask)
-
-
-if __name__ == '__main__':
- MakePadMask_test()
\ No newline at end of file
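A worked example of the sequence_mask module above (the non-ONNX path): lengths [2, 4] produce a (B, T) validity mask with ones on real frames and zeros on padding:

import torch

lengths = torch.tensor([2, 4])
mask = sequence_mask()(lengths)   # maxlen defaults to lengths.max() == 4
print(mask)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])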
diff --git a/funasr/export/__init__.py b/funasr/frontends/__init__.py
similarity index 100%
rename from funasr/export/__init__.py
rename to funasr/frontends/__init__.py
diff --git a/funasr/models/frontend/default.py b/funasr/frontends/default.py
similarity index 95%
rename from funasr/models/frontend/default.py
rename to funasr/frontends/default.py
index b4e518a..8ac1ca8 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/frontends/default.py
@@ -6,20 +6,19 @@
import humanfriendly
import numpy as np
import torch
+import torch.nn as nn
try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
-from funasr.models.frontend.utils.log_mel import LogMel
-from funasr.models.frontend.utils.stft import Stft
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.utils.frontend import Frontend
-from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.frontends.utils.log_mel import LogMel
+from funasr.frontends.utils.stft import Stft
+from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-class DefaultFrontend(AbsFrontend):
+class DefaultFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -38,7 +37,7 @@
fmin: int = None,
fmax: int = None,
htk: bool = False,
- frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+ frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
):
@@ -139,7 +138,7 @@
return input_stft, feats_lens
-class MultiChannelFrontend(AbsFrontend):
+class MultiChannelFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -160,7 +159,7 @@
fmin: int = None,
fmax: int = None,
htk: bool = False,
- frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+ frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
lfr_m: int = 1,
diff --git a/funasr/models/frontend/eend_ola_feature.py b/funasr/frontends/eend_ola_feature.py
similarity index 100%
rename from funasr/models/frontend/eend_ola_feature.py
rename to funasr/frontends/eend_ola_feature.py
diff --git a/funasr/models/frontend/fused.py b/funasr/frontends/fused.py
similarity index 95%
rename from funasr/models/frontend/fused.py
rename to funasr/frontends/fused.py
index ff95871..24f73f4 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/frontends/fused.py
@@ -1,12 +1,12 @@
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.frontends.default import DefaultFrontend
+from funasr.frontends.s3prl import S3prlFrontend
import numpy as np
import torch
+import torch.nn as nn
from typing import Tuple
-class FusedFrontends(AbsFrontend):
+class FusedFrontends(nn.Module):
def __init__(
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
diff --git a/funasr/models/frontend/s3prl.py b/funasr/frontends/s3prl.py
similarity index 93%
rename from funasr/models/frontend/s3prl.py
rename to funasr/frontends/s3prl.py
index 0e419bc..ff60592 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/frontends/s3prl.py
@@ -8,11 +8,10 @@
import humanfriendly
import torch
+import torch.nn as nn
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.utils.frontend import Frontend
+from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.utils.get_default_kwargs import get_default_kwargs
def base_s3prl_setup(args):
@@ -26,13 +25,13 @@
return args
-class S3prlFrontend(AbsFrontend):
+class S3prlFrontend(nn.Module):
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
self,
fs: Union[int, str] = 16000,
- frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+ frontend_conf: Optional[dict] = None,
download_dir: str = None,
multilayer_feature: bool = False,
):
diff --git a/funasr/models/frontend/utils/__init__.py b/funasr/frontends/utils/__init__.py
similarity index 100%
rename from funasr/models/frontend/utils/__init__.py
rename to funasr/frontends/utils/__init__.py
diff --git a/funasr/models/frontend/utils/beamformer.py b/funasr/frontends/utils/beamformer.py
similarity index 100%
rename from funasr/models/frontend/utils/beamformer.py
rename to funasr/frontends/utils/beamformer.py
diff --git a/funasr/models/frontend/utils/complex_utils.py b/funasr/frontends/utils/complex_utils.py
similarity index 100%
rename from funasr/models/frontend/utils/complex_utils.py
rename to funasr/frontends/utils/complex_utils.py
diff --git a/funasr/models/frontend/utils/dnn_beamformer.py b/funasr/frontends/utils/dnn_beamformer.py
similarity index 94%
rename from funasr/models/frontend/utils/dnn_beamformer.py
rename to funasr/frontends/utils/dnn_beamformer.py
index 0f4e9da..75637d2 100644
--- a/funasr/models/frontend/utils/dnn_beamformer.py
+++ b/funasr/frontends/utils/dnn_beamformer.py
@@ -4,12 +4,12 @@
import torch
from torch.nn import functional as F
-from funasr.models.frontend.utils.beamformer import apply_beamforming_vector
-from funasr.models.frontend.utils.beamformer import get_mvdr_vector
-from funasr.models.frontend.utils.beamformer import (
+from funasr.frontends.utils.beamformer import apply_beamforming_vector
+from funasr.frontends.utils.beamformer import get_mvdr_vector
+from funasr.frontends.utils.beamformer import (
get_power_spectral_density_matrix, # noqa: H301
)
-from funasr.models.frontend.utils.mask_estimator import MaskEstimator
+from funasr.frontends.utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor
diff --git a/funasr/models/frontend/utils/dnn_wpe.py b/funasr/frontends/utils/dnn_wpe.py
similarity index 97%
rename from funasr/models/frontend/utils/dnn_wpe.py
rename to funasr/frontends/utils/dnn_wpe.py
index 6a14c5d..9171339 100644
--- a/funasr/models/frontend/utils/dnn_wpe.py
+++ b/funasr/frontends/utils/dnn_wpe.py
@@ -4,7 +4,7 @@
import torch
from torch_complex.tensor import ComplexTensor
-from funasr.models.frontend.utils.mask_estimator import MaskEstimator
+from funasr.frontends.utils.mask_estimator import MaskEstimator
from funasr.models.transformer.utils.nets_utils import make_pad_mask
diff --git a/funasr/models/frontend/utils/feature_transform.py b/funasr/frontends/utils/feature_transform.py
similarity index 100%
rename from funasr/models/frontend/utils/feature_transform.py
rename to funasr/frontends/utils/feature_transform.py
diff --git a/funasr/models/frontend/utils/frontend.py b/funasr/frontends/utils/frontend.py
similarity index 96%
rename from funasr/models/frontend/utils/frontend.py
rename to funasr/frontends/utils/frontend.py
index 8f84e73..a2a6cc1 100644
--- a/funasr/models/frontend/utils/frontend.py
+++ b/funasr/frontends/utils/frontend.py
@@ -8,8 +8,8 @@
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
-from funasr.models.frontend.utils.dnn_beamformer import DNN_Beamformer
-from funasr.models.frontend.utils.dnn_wpe import DNN_WPE
+from funasr.frontends.utils.dnn_beamformer import DNN_Beamformer
+from funasr.frontends.utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module):
diff --git a/funasr/models/frontend/utils/log_mel.py b/funasr/frontends/utils/log_mel.py
similarity index 100%
rename from funasr/models/frontend/utils/log_mel.py
rename to funasr/frontends/utils/log_mel.py
diff --git a/funasr/models/frontend/utils/mask_estimator.py b/funasr/frontends/utils/mask_estimator.py
similarity index 100%
rename from funasr/models/frontend/utils/mask_estimator.py
rename to funasr/frontends/utils/mask_estimator.py
diff --git a/funasr/models/frontend/utils/stft.py b/funasr/frontends/utils/stft.py
similarity index 98%
rename from funasr/models/frontend/utils/stft.py
rename to funasr/frontends/utils/stft.py
index 2a0bc96..00d9ec5 100644
--- a/funasr/models/frontend/utils/stft.py
+++ b/funasr/frontends/utils/stft.py
@@ -10,7 +10,7 @@
except:
print("Please install torch_complex firstly")
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.frontend.utils.complex_utils import is_complex
+from funasr.frontends.utils.complex_utils import is_complex
import librosa
import numpy as np
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/frontends/wav_frontend.py
similarity index 97%
rename from funasr/models/frontend/wav_frontend.py
rename to funasr/frontends/wav_frontend.py
index ac16065..4866fa1 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -4,11 +4,13 @@
import numpy as np
import torch
+import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.frontends.eend_ola_feature as eend_ola_feature
+from funasr.utils.register import register_class
+
def load_cmvn(cmvn_file):
@@ -73,8 +75,8 @@
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
-class WavFrontend(AbsFrontend):
+@register_class("frontend_classes", "WavFrontend")
+class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR.
"""
@@ -93,6 +95,7 @@
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -208,7 +211,8 @@
return feats_pad, feats_lens
-class WavFrontendOnline(AbsFrontend):
+@register_class("frontend_classes", "WavFrontendOnline")
+class WavFrontendOnline(nn.Module):
"""Conventional frontend structure for streaming ASR/VAD.
"""
@@ -227,6 +231,7 @@
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
+ **kwargs,
):
super().__init__()
self.fs = fs
@@ -454,7 +459,7 @@
self.lfr_splice_cache = []
-class WavFrontendMel23(AbsFrontend):
+class WavFrontendMel23(nn.Module):
"""Conventional frontend structure for ASR.
"""
@@ -465,6 +470,7 @@
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
+ **kwargs,
):
super().__init__()
self.fs = fs
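The hunk above drops the old AbsFrontend base class: frontends are now plain torch.nn.Module subclasses that register themselves by name through register_class. A minimal sketch of the same pattern for a custom frontend (the class name and pass-through logic below are hypothetical, not part of this patch):

    import torch
    import torch.nn as nn

    from funasr.utils.register import register_class


    @register_class("frontend_classes", "IdentityFrontend")  # hypothetical registry entry
    class IdentityFrontend(nn.Module):
        def __init__(self, fs: int = 16000, **kwargs):
            super().__init__()
            self.fs = fs

        def forward(self, waveform: torch.Tensor, lengths: torch.Tensor):
            # return the raw waveform unchanged; real frontends return fbank features and their lengths
            return waveform, lengths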
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/frontends/wav_frontend_kaldifeat.py
similarity index 100%
rename from funasr/models/frontend/wav_frontend_kaldifeat.py
rename to funasr/frontends/wav_frontend_kaldifeat.py
diff --git a/funasr/models/frontend/windowing.py b/funasr/frontends/windowing.py
similarity index 96%
rename from funasr/models/frontend/windowing.py
rename to funasr/frontends/windowing.py
index 94c9d27..3250550 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/frontends/windowing.py
@@ -4,12 +4,12 @@
"""Sliding Window for raw audio input data."""
-from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
+import torch.nn as nn
from typing import Tuple
-class SlidingWindow(AbsFrontend):
+class SlidingWindow(nn.Module):
"""Sliding Window.
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
new file mode 100644
index 0000000..2b45836
--- /dev/null
+++ b/funasr/metrics/compute_acc.py
@@ -0,0 +1,23 @@
+import torch
+
+def th_accuracy(pad_outputs, pad_targets, ignore_label):
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ float: Accuracy value (0.0 - 1.0).
+
+ """
+ pad_pred = pad_outputs.view(
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
+ ).argmax(2)
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
+ )
+ denominator = torch.sum(mask)
+ return float(numerator) / float(denominator)
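A minimal usage sketch for th_accuracy, assuming flattened decoder logits and padded targets that use -1 as the ignore label (the toy shapes below are illustrative only):

    import torch

    from funasr.metrics.compute_acc import th_accuracy

    B, Lmax, V = 2, 3, 5
    logits = torch.randn(B * Lmax, V)                 # (B * Lmax, D) prediction scores
    targets = torch.tensor([[1, 2, -1], [3, 4, 0]])   # (B, Lmax); -1 marks padding
    acc = th_accuracy(logits, targets, ignore_label=-1)
    assert 0.0 <= acc <= 1.0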
diff --git a/funasr/models/bat/attention.py b/funasr/models/bat/attention.py
new file mode 100644
index 0000000..11645b3
--- /dev/null
+++ b/funasr/models/bat/attention.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+ """RelPositionMultiHeadedAttention definition.
+ Args:
+ num_heads: Number of attention heads.
+ embed_size: Embedding size.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ num_heads: int,
+ embed_size: int,
+ dropout_rate: float = 0.0,
+ simplified_attention_score: bool = False,
+ ) -> None:
+        """Construct a MultiHeadedAttention object."""
+ super().__init__()
+
+ self.d_k = embed_size // num_heads
+ self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
+
+ self.linear_q = torch.nn.Linear(embed_size, embed_size)
+ self.linear_k = torch.nn.Linear(embed_size, embed_size)
+ self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+ self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+ if simplified_attention_score:
+ self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+ self.compute_att_score = self.compute_simplified_attention_score
+ else:
+ self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+ self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ self.compute_att_score = self.compute_attention_score
+
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.attn = None
+
+ def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+ """Compute relative positional encoding.
+ Args:
+ x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+ left_context: Number of frames in left context.
+ Returns:
+ x: Output sequence. (B, H, T_1, T_2)
+ """
+ batch_size, n_heads, time1, n = x.shape
+ time2 = time1 + left_context
+
+ batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+ return x.as_strided(
+ (batch_size, n_heads, time1, time2),
+ (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+ storage_offset=(n_stride * (time1 - 1)),
+ )
+
+ def compute_simplified_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Simplified attention score computation.
+ Reference: https://github.com/k2-fsa/icefall/pull/458
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ pos_enc = self.linear_pos(pos_enc)
+
+ matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+ matrix_bd = self.rel_shift(
+ pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+ left_context=left_context,
+ )
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def compute_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Attention score computation.
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+ matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+ Returns:
+ q: Transformed query tensor. (B, H, T_1, d_k)
+ k: Transformed key tensor. (B, H, T_2, d_k)
+ v: Transformed value tensor. (B, H, T_2, d_k)
+ """
+ n_batch = query.size(0)
+
+ q = (
+ self.linear_q(query)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ k = (
+ self.linear_k(key)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ v = (
+ self.linear_v(value)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ value: Transformed value. (B, H, T_2, d_k)
+ scores: Attention score. (B, H, T_1, T_2)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ Returns:
+ attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+ """
+ batch_size = scores.size(0)
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ if chunk_mask is not None:
+ mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+ scores = scores.masked_fill(mask, float("-inf"))
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+ attn_output = self.dropout(self.attn)
+ attn_output = torch.matmul(attn_output, value)
+
+ attn_output = self.linear_out(
+ attn_output.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.num_heads * self.d_k)
+ )
+
+ return attn_output
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Compute scaled dot product attention with rel. positional encoding.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+ value: Value tensor. (B, T_2, size)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ left_context: Number of frames in left context.
+ Returns:
+ : Output tensor. (B, T_1, H * d_k)
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+ return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
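A shape-check sketch for RelPositionMultiHeadedAttentionChunk as defined above, assuming relative positional embeddings of length 2 * T - 1 and a boolean source mask in which True marks padded frames:

    import torch

    from funasr.models.bat.attention import RelPositionMultiHeadedAttentionChunk

    B, T, D, H = 2, 8, 256, 4
    attn = RelPositionMultiHeadedAttentionChunk(num_heads=H, embed_size=D)
    x = torch.randn(B, T, D)
    pos_enc = torch.randn(B, 2 * T - 1, D)       # relative positional embeddings
    mask = torch.zeros(B, T, dtype=torch.bool)   # no padding in this toy batch
    out = attn(x, x, x, pos_enc, mask)
    assert out.shape == (B, T, D)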
diff --git a/funasr/models/bat/cif_predictor.py b/funasr/models/bat/cif_predictor.py
new file mode 100644
index 0000000..9aa3e33
--- /dev/null
+++ b/funasr/models/bat/cif_predictor.py
@@ -0,0 +1,220 @@
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import logging
+# import numpy as np
+# from funasr.train_utils.device_funcs import to_device
+# from funasr.models.transformer.utils.nets_utils import make_pad_mask
+# from funasr.models.scama.utils import sequence_mask
+# from typing import Optional, Tuple
+#
+# from funasr.utils.register import register_class
+#
+# class mae_loss(nn.Module):
+#
+# def __init__(self, normalize_length=False):
+# super(mae_loss, self).__init__()
+# self.normalize_length = normalize_length
+# self.criterion = torch.nn.L1Loss(reduction='sum')
+#
+# def forward(self, token_length, pre_token_length):
+# loss_token_normalizer = token_length.size(0)
+# if self.normalize_length:
+# loss_token_normalizer = token_length.sum().type(torch.float32)
+# loss = self.criterion(token_length, pre_token_length)
+# loss = loss / loss_token_normalizer
+# return loss
+#
+#
+# def cif(hidden, alphas, threshold):
+# batch_size, len_time, hidden_size = hidden.size()
+#
+# # loop vars
+# integrate = torch.zeros([batch_size], device=hidden.device)
+# frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+# # intermediate vars along time
+# list_fires = []
+# list_frames = []
+#
+# for t in range(len_time):
+# alpha = alphas[:, t]
+# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+#
+# integrate += alpha
+# list_fires.append(integrate)
+#
+# fire_place = integrate >= threshold
+# integrate = torch.where(fire_place,
+# integrate - torch.ones([batch_size], device=hidden.device),
+# integrate)
+# cur = torch.where(fire_place,
+# distribution_completion,
+# alpha)
+# remainds = alpha - cur
+#
+# frame += cur[:, None] * hidden[:, t, :]
+# list_frames.append(frame)
+# frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+# remainds[:, None] * hidden[:, t, :],
+# frame)
+#
+# fires = torch.stack(list_fires, 1)
+# frames = torch.stack(list_frames, 1)
+# list_ls = []
+# len_labels = torch.round(alphas.sum(-1)).int()
+# max_label_len = len_labels.max()
+# for b in range(batch_size):
+# fire = fires[b, :]
+# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
+# pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
+# list_ls.append(torch.cat([l, pad_l], 0))
+# return torch.stack(list_ls, 0), fires
+#
+#
+# def cif_wo_hidden(alphas, threshold):
+# batch_size, len_time = alphas.size()
+#
+# # loop vars
+# integrate = torch.zeros([batch_size], device=alphas.device)
+# # intermediate vars along time
+# list_fires = []
+#
+# for t in range(len_time):
+# alpha = alphas[:, t]
+#
+# integrate += alpha
+# list_fires.append(integrate)
+#
+# fire_place = integrate >= threshold
+# integrate = torch.where(fire_place,
+# integrate - torch.ones([batch_size], device=alphas.device)*threshold,
+# integrate)
+#
+# fires = torch.stack(list_fires, 1)
+# return fires
+#
+# @register_class("predictor_classes", "BATPredictor")
+# class BATPredictor(nn.Module):
+# def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
+# super(BATPredictor, self).__init__()
+#
+# self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+# self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+# self.cif_output = nn.Linear(idim, 1)
+# self.dropout = torch.nn.Dropout(p=dropout)
+# self.threshold = threshold
+# self.smooth_factor = smooth_factor
+# self.noise_threshold = noise_threshold
+# self.return_accum = return_accum
+#
+# def cif(
+# self,
+# input: Tensor,
+# alpha: Tensor,
+# beta: float = 1.0,
+# return_accum: bool = False,
+# ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+# B, S, C = input.size()
+# assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
+#
+# dtype = alpha.dtype
+# alpha = alpha.float()
+#
+# alpha_sum = alpha.sum(1)
+# feat_lengths = (alpha_sum / beta).floor().long()
+# T = feat_lengths.max()
+#
+# # aggregate and integrate
+# csum = alpha.cumsum(-1)
+# with torch.no_grad():
+# # indices used for scattering
+# right_idx = (csum / beta).floor().long().clip(max=T)
+# left_idx = right_idx.roll(1, dims=1)
+# left_idx[:, 0] = 0
+#
+# # count # of fires from each source
+# fire_num = right_idx - left_idx
+# extra_weights = (fire_num - 1).clip(min=0)
+# # The extra entry in last dim is for
+# output = input.new_zeros((B, T + 1, C))
+# source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
+# zero = alpha.new_zeros((1,))
+#
+# # right scatter
+# fire_mask = fire_num > 0
+# right_weight = torch.where(
+# fire_mask,
+# csum - right_idx.type_as(alpha) * beta,
+# zero
+# ).type_as(input)
+# # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
+# output.scatter_add_(
+# 1,
+# right_idx.unsqueeze(-1).expand(-1, -1, C),
+# right_weight.unsqueeze(-1) * input
+# )
+#
+# # left scatter
+# left_weight = (
+# alpha - right_weight - extra_weights.type_as(alpha) * beta
+# ).type_as(input)
+# output.scatter_add_(
+# 1,
+# left_idx.unsqueeze(-1).expand(-1, -1, C),
+# left_weight.unsqueeze(-1) * input
+# )
+#
+# # extra scatters
+# if extra_weights.ge(0).any():
+# extra_steps = extra_weights.max().item()
+# tgt_idx = left_idx
+# src_feats = input * beta
+# for _ in range(extra_steps):
+# tgt_idx = (tgt_idx + 1).clip(max=T)
+# # (B, S, 1)
+# src_mask = (extra_weights > 0)
+# output.scatter_add_(
+# 1,
+# tgt_idx.unsqueeze(-1).expand(-1, -1, C),
+# src_feats * src_mask.unsqueeze(2)
+# )
+# extra_weights -= 1
+#
+# output = output[:, :T, :]
+#
+# if return_accum:
+# return output, csum
+# else:
+# return output, alpha
+#
+# def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
+# h = hidden
+# context = h.transpose(1, 2)
+# queries = self.pad(context)
+# memory = self.cif_conv1d(queries)
+# output = memory + context
+# output = self.dropout(output)
+# output = output.transpose(1, 2)
+# output = torch.relu(output)
+# output = self.cif_output(output)
+# alphas = torch.sigmoid(output)
+# alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
+# if mask is not None:
+# alphas = alphas * mask.transpose(-1, -2).float()
+# if mask_chunk_predictor is not None:
+# alphas = alphas * mask_chunk_predictor
+# alphas = alphas.squeeze(-1)
+# if target_label_length is not None:
+# target_length = target_label_length
+# elif target_label is not None:
+# target_length = (target_label != ignore_id).float().sum(-1)
+# # logging.info("target_length: {}".format(target_length))
+# else:
+# target_length = None
+# token_num = alphas.sum(-1)
+# if target_length is not None:
+# # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
+# # target_length = length_noise + target_length
+# alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
+# acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
+# return acoustic_embeds, token_num, alphas, cif_peak
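The commented-out BATPredictor above replaces the frame-by-frame integrate-and-fire loop with a vectorized scatter. A standalone toy sketch of the index bookkeeping it relies on (values chosen to be exact in binary floating point; nothing here imports the commented file):

    import torch

    beta = 1.0
    alpha = torch.tensor([[0.5, 0.5, 0.25, 0.75, 0.5, 0.5]])  # per-frame weights, sum = 3.0
    csum = alpha.cumsum(-1)                                   # running integrated weight
    T = int((alpha.sum(1) / beta).floor().item())             # number of fired tokens: 3
    right_idx = (csum / beta).floor().long().clip(max=T)      # token index each frame ends in
    left_idx = right_idx.roll(1, dims=1)
    left_idx[:, 0] = 0                                        # token index each frame starts in
    fire_num = right_idx - left_idx                           # boundaries crossed per frame
    print(right_idx.tolist(), fire_num.tolist())              # [[0, 1, 1, 2, 2, 3]] [[0, 1, 0, 1, 0, 1]]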
diff --git a/funasr/models/bat/conformer_chunk_encoder.py b/funasr/models/bat/conformer_chunk_encoder.py
new file mode 100644
index 0000000..2dc03c3
--- /dev/null
+++ b/funasr/models/bat/conformer_chunk_encoder.py
@@ -0,0 +1,701 @@
+
+"""Conformer encoder definition."""
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import torch
+from torch import nn
+
+
+from funasr.models.bat.attention import (
+ RelPositionMultiHeadedAttentionChunk,
+)
+from funasr.models.transformer.embedding import (
+ StreamingRelPositionalEncoding,
+)
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.nets_utils import get_activation
+from funasr.models.transformer.utils.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import StreamingConvInput
+from funasr.utils.register import register_class
+
+
+
+class ChunkEncoderLayer(nn.Module):
+ """Chunk Conformer module definition.
+ Args:
+ block_size: Input/output size.
+ self_att: Self-attention module instance.
+ feed_forward: Feed-forward module instance.
+ feed_forward_macaron: Feed-forward module instance for macaron network.
+ conv_mod: Convolution module instance.
+ norm_class: Normalization module class.
+ norm_args: Normalization module arguments.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ block_size: int,
+ self_att: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: torch.nn.Module,
+ conv_mod: torch.nn.Module,
+ norm_class: torch.nn.Module = LayerNorm,
+ norm_args: Dict = {},
+ dropout_rate: float = 0.0,
+ ) -> None:
+ """Construct a Conformer object."""
+ super().__init__()
+
+ self.self_att = self_att
+
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.feed_forward_scale = 0.5
+
+ self.conv_mod = conv_mod
+
+ self.norm_feed_forward = norm_class(block_size, **norm_args)
+ self.norm_self_att = norm_class(block_size, **norm_args)
+
+ self.norm_macaron = norm_class(block_size, **norm_args)
+ self.norm_conv = norm_class(block_size, **norm_args)
+ self.norm_final = norm_class(block_size, **norm_args)
+
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.block_size = block_size
+ self.cache = None
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset self-attention and convolution modules cache for streaming.
+ Args:
+ left_context: Number of left frames during chunk-by-chunk inference.
+ device: Device to use for cache tensor.
+ """
+ self.cache = [
+ torch.zeros(
+ (1, left_context, self.block_size),
+ device=device,
+ ),
+ torch.zeros(
+ (
+ 1,
+ self.block_size,
+ self.conv_mod.kernel_size - 1,
+ ),
+ device=device,
+ ),
+ ]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ mask: Source mask. (B, T)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.dropout(
+ self.feed_forward_macaron(x)
+ )
+
+ residual = x
+ x = self.norm_self_att(x)
+ x_q = x
+ x = residual + self.dropout(
+ self.self_att(
+ x_q,
+ x,
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ )
+
+ residual = x
+
+ x = self.norm_conv(x)
+ x, _ = self.conv_mod(x)
+ x = residual + self.dropout(x)
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+ x = self.norm_final(x)
+ return x, mask, pos_enc
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode chunk of input sequence.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+ residual = x
+ x = self.norm_self_att(x)
+ if left_context > 0:
+ key = torch.cat([self.cache[0], x], dim=1)
+ else:
+ key = x
+ val = key
+
+ if right_context > 0:
+ att_cache = key[:, -(left_context + right_context) : -right_context, :]
+ else:
+ att_cache = key[:, -left_context:, :]
+ x = residual + self.self_att(
+ x,
+ key,
+ val,
+ pos_enc,
+ mask,
+ left_context=left_context,
+ )
+
+ residual = x
+ x = self.norm_conv(x)
+ x, conv_cache = self.conv_mod(
+ x, cache=self.cache[1], right_context=right_context
+ )
+ x = residual + x
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+ x = self.norm_final(x)
+ self.cache = [att_cache, conv_cache]
+
+ return x, pos_enc
+
+
+
+class CausalConvolution(nn.Module):
+ """ConformerConvolution module definition.
+ Args:
+ channels: The number of channels.
+ kernel_size: Size of the convolving kernel.
+ activation: Type of activation function.
+ norm_args: Normalization module arguments.
+ causal: Whether to use causal convolution (set to True if streaming).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ norm_args: Dict = {},
+ causal: bool = False,
+ ) -> None:
+        """Construct a ConformerConvolution object."""
+ super().__init__()
+
+ assert (kernel_size - 1) % 2 == 0
+
+ self.kernel_size = kernel_size
+
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ if causal:
+ self.lorder = kernel_size - 1
+ padding = 0
+ else:
+ self.lorder = 0
+ padding = (kernel_size - 1) // 2
+
+ self.depthwise_conv = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ )
+ self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cache: Optional[torch.Tensor] = None,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x: ConformerConvolution input sequences. (B, T, D_hidden)
+ cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+ right_context: Number of frames in right context.
+ Returns:
+ x: ConformerConvolution output sequences. (B, T, D_hidden)
+ cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+ """
+ x = self.pointwise_conv1(x.transpose(1, 2))
+ x = torch.nn.functional.glu(x, dim=1)
+
+ if self.lorder > 0:
+ if cache is None:
+ x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ x = torch.cat([cache, x], dim=2)
+
+ if right_context > 0:
+ cache = x[:, :, -(self.lorder + right_context) : -right_context]
+ else:
+ cache = x[:, :, -self.lorder :]
+
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x).transpose(1, 2)
+
+ return x, cache
+
+@register_class("encoder_classes", "ConformerChunkEncoder")
+class ConformerChunkEncoder(nn.Module):
+    """Chunk-capable Conformer encoder module definition.
+    Args:
+        input_size: Input feature size.
+        output_size: Encoder output size.
+        num_blocks: Number of Conformer blocks.
+        dynamic_chunk_training / unified_model_training: Chunk training modes.
+    """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ embed_vgg_like: bool = False,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ norm_type: str = "layer_norm",
+ cnn_module_kernel: int = 31,
+ conv_mod_norm_eps: float = 0.00001,
+ conv_mod_norm_momentum: float = 0.1,
+ simplified_att_score: bool = False,
+ dynamic_chunk_training: bool = False,
+ short_chunk_threshold: float = 0.75,
+ short_chunk_size: int = 25,
+ left_chunk_size: int = 0,
+ time_reduction_factor: int = 1,
+ unified_model_training: bool = False,
+ default_chunk_size: int = 16,
+ jitter_range: int = 4,
+ subsampling_factor: int = 1,
+ ) -> None:
+ """Construct an Encoder object."""
+ super().__init__()
+
+
+ self.embed = StreamingConvInput(
+ input_size,
+ output_size,
+ subsampling_factor,
+ vgg_like=embed_vgg_like,
+ output_size=output_size,
+ )
+
+ self.pos_enc = StreamingRelPositionalEncoding(
+ output_size,
+ positional_dropout_rate,
+ )
+
+ activation = get_activation(
+ activation_type
+ )
+
+ pos_wise_args = (
+ output_size,
+ linear_units,
+ positional_dropout_rate,
+ activation,
+ )
+
+ conv_mod_norm_args = {
+ "eps": conv_mod_norm_eps,
+ "momentum": conv_mod_norm_momentum,
+ }
+
+ conv_mod_args = (
+ output_size,
+ cnn_module_kernel,
+ activation,
+ conv_mod_norm_args,
+ dynamic_chunk_training or unified_model_training,
+ )
+
+ mult_att_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ simplified_att_score,
+ )
+
+
+ fn_modules = []
+ for _ in range(num_blocks):
+ module = lambda: ChunkEncoderLayer(
+ output_size,
+ RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ CausalConvolution(*conv_mod_args),
+ dropout_rate=dropout_rate,
+ )
+ fn_modules.append(module)
+
+ self.encoders = MultiBlocks(
+ [fn() for fn in fn_modules],
+ output_size,
+ )
+
+ self._output_size = output_size
+
+ self.dynamic_chunk_training = dynamic_chunk_training
+ self.short_chunk_threshold = short_chunk_threshold
+ self.short_chunk_size = short_chunk_size
+ self.left_chunk_size = left_chunk_size
+
+ self.unified_model_training = unified_model_training
+ self.default_chunk_size = default_chunk_size
+ self.jitter_range = jitter_range
+
+ self.time_reduction_factor = time_reduction_factor
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the number of raw audio samples corresponding to a given chunk size.
+        Here, size is the number of feature frames after subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ hop_length: Frontend's hop length
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size) * hop_length
+
+ def get_encoder_input_size(self, size: int) -> int:
+        """Return the number of feature frames before subsampling for a given chunk size.
+        Here, size is the number of feature frames after subsampling.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of feature frames before subsampling
+ """
+ return self.embed.get_size_before_subsampling(size)
+
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset encoder streaming cache.
+ Args:
+ left_context: Number of frames in left context.
+ device: Device ID.
+ """
+ return self.encoders.reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+            x_len: Encoder outputs lengths. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+
+ if self.unified_model_training:
+ if self.training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ else:
+ chunk_size = self.default_chunk_size
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+ x_chunk = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x_utt, x_chunk, olens
+
+ elif self.dynamic_chunk_training:
+ max_len = x.size(1)
+ if self.training:
+ chunk_size = torch.randint(1, max_len, (1,)).item()
+
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
+ else:
+ chunk_size = self.default_chunk_size
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ else:
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = None
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x, olens, None
+
+ def full_utt_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+            x_len: Encoder outputs lengths. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ return x_utt
+
+ def simu_chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+        processed_frames: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Encode input sequences as chunks.
+ Args:
+ x: Encoder input features. (1, T_in, F)
+ x_len: Encoder input features lengths. (1,)
+ processed_frames: Number of frames already seen.
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ """
+ mask = make_source_mask(x_len)
+ x, mask = self.embed(x, mask, None)
+
+ if left_context > 0:
+ processed_mask = (
+ torch.arange(left_context, device=x.device)
+ .view(1, left_context)
+ .flip(1)
+ )
+ processed_mask = processed_mask >= processed_frames
+ mask = torch.cat([processed_mask, mask], dim=1)
+ pos_enc = self.pos_enc(x, left_context=left_context)
+ x = self.encoders.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ if right_context > 0:
+ x = x[:, 0:-right_context, :]
+
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ return x
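A hedged end-to-end sketch of running ConformerChunkEncoder on a batch of fbank features, assuming the encoder constructs with the defaults committed above (feature sizes are illustrative; eval mode takes the fixed default_chunk_size path):

    import torch

    from funasr.models.bat.conformer_chunk_encoder import ConformerChunkEncoder

    enc = ConformerChunkEncoder(
        input_size=80,               # fbank dimension
        output_size=256,
        attention_heads=4,
        num_blocks=2,
        dynamic_chunk_training=True,
        default_chunk_size=16,
    )
    enc.eval()
    feats = torch.randn(2, 200, 80)
    feats_len = torch.tensor([200, 180])
    with torch.no_grad():
        x, olens, _ = enc(feats, feats_len)   # (B, T_out, 256), output lengths (B,)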
diff --git a/funasr/models/bat/model.py b/funasr/models/bat/model.py
index 83ce302..d814e31 100644
--- a/funasr/models/bat/model.py
+++ b/funasr/models/bat/model.py
@@ -5,34 +5,23 @@
from typing import Dict, List, Optional, Tuple, Union
import torch
+import torch.nn as nn
from packaging.version import parse as V
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.joint_net.joint_network import JointNetwork
+
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
-if V(torch.__version__) >= V("1.6.0"):
- from torch.cuda.amp import autocast
-else:
-
- @contextmanager
- def autocast(enabled=True):
- yield
+from torch.cuda.amp import autocast
-class BATModel(FunASRModel):
+
+class BATModel(nn.Module):
"""BATModel module definition.
Args:
@@ -61,18 +50,7 @@
def __init__(
self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- encoder: AbsEncoder,
- decoder: RNNTDecoder,
- joint_network: JointNetwork,
- att_decoder: Optional[AbsAttDecoder] = None,
- predictor = None,
- transducer_weight: float = 1.0,
- predictor_weight: float = 1.0,
+
cif_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
@@ -89,6 +67,7 @@
length_normalized_loss: bool = False,
r_d: int = 5,
r_u: int = 5,
+ **kwargs,
) -> None:
"""Construct an BATModel object."""
super().__init__()
diff --git a/funasr/models/bici_paraformer/cif_predictor.py b/funasr/models/bici_paraformer/cif_predictor.py
new file mode 100644
index 0000000..67d801c
--- /dev/null
+++ b/funasr/models/bici_paraformer/cif_predictor.py
@@ -0,0 +1,340 @@
+import torch
+from torch import nn
+from torch import Tensor
+import logging
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.scama.utils import sequence_mask
+from typing import Optional, Tuple
+
+from funasr.utils.register import register_class
+
+
+class mae_loss(nn.Module):
+
+ def __init__(self, normalize_length=False):
+ super(mae_loss, self).__init__()
+ self.normalize_length = normalize_length
+ self.criterion = torch.nn.L1Loss(reduction='sum')
+
+ def forward(self, token_length, pre_token_length):
+ loss_token_normalizer = token_length.size(0)
+ if self.normalize_length:
+ loss_token_normalizer = token_length.sum().type(torch.float32)
+ loss = self.criterion(token_length, pre_token_length)
+ loss = loss / loss_token_normalizer
+ return loss
+
+
+def cif(hidden, alphas, threshold):
+ batch_size, len_time, hidden_size = hidden.size()
+
+    # loop vars
+ integrate = torch.zeros([batch_size], device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+ # intermediate vars along time
+ list_fires = []
+ list_frames = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], device=hidden.device),
+ integrate)
+ cur = torch.where(fire_place,
+ distribution_completion,
+ alpha)
+ remainds = alpha - cur
+
+ frame += cur[:, None] * hidden[:, t, :]
+ list_frames.append(frame)
+ frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+ remainds[:, None] * hidden[:, t, :],
+ frame)
+
+ fires = torch.stack(list_fires, 1)
+ frames = torch.stack(list_frames, 1)
+ list_ls = []
+ len_labels = torch.round(alphas.sum(-1)).int()
+ max_label_len = len_labels.max()
+ for b in range(batch_size):
+ fire = fires[b, :]
+ l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
+ pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
+ list_ls.append(torch.cat([l, pad_l], 0))
+ return torch.stack(list_ls, 0), fires
+
+
+def cif_wo_hidden(alphas, threshold):
+ batch_size, len_time = alphas.size()
+
+    # loop vars
+ integrate = torch.zeros([batch_size], device=alphas.device)
+ # intermediate vars along time
+ list_fires = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], device=alphas.device)*threshold,
+ integrate)
+
+ fires = torch.stack(list_fires, 1)
+ return fires
+
+@register_class("predictor_classes", "CifPredictorV3")
+class CifPredictorV3(nn.Module):
+ def __init__(self,
+ idim,
+ l_order,
+ r_order,
+ threshold=1.0,
+ dropout=0.1,
+ smooth_factor=1.0,
+ noise_threshold=0,
+ tail_threshold=0.0,
+ tf2torch_tensor_name_prefix_torch="predictor",
+ tf2torch_tensor_name_prefix_tf="seq2seq/cif",
+ smooth_factor2=1.0,
+ noise_threshold2=0,
+ upsample_times=5,
+ upsample_type="cnn",
+ use_cif1_cnn=True,
+ tail_mask=True,
+ ):
+ super(CifPredictorV3, self).__init__()
+
+ self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
+ self.cif_output = nn.Linear(idim, 1)
+ self.dropout = torch.nn.Dropout(p=dropout)
+ self.threshold = threshold
+ self.smooth_factor = smooth_factor
+ self.noise_threshold = noise_threshold
+ self.tail_threshold = tail_threshold
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+ self.upsample_times = upsample_times
+ self.upsample_type = upsample_type
+ self.use_cif1_cnn = use_cif1_cnn
+ if self.upsample_type == 'cnn':
+ self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.cif_output2 = nn.Linear(idim, 1)
+ elif self.upsample_type == 'cnn_blstm':
+ self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
+ self.cif_output2 = nn.Linear(idim*2, 1)
+ elif self.upsample_type == 'cnn_attn':
+ self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+ from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
+ from funasr.models.transformer.attention import MultiHeadedAttention
+ from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+ positionwise_layer_args = (
+ idim,
+ idim*2,
+ 0.1,
+ )
+ self.self_attn = TransformerEncoderLayer(
+ idim,
+ MultiHeadedAttention(
+ 4, idim, 0.1
+ ),
+ PositionwiseFeedForward(*positionwise_layer_args),
+ 0.1,
+ True, #normalize_before,
+ False, #concat_after,
+ )
+ self.cif_output2 = nn.Linear(idim, 1)
+ self.smooth_factor2 = smooth_factor2
+ self.noise_threshold2 = noise_threshold2
+
+ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
+ target_label_length=None):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+
+ # alphas2 is an extra head for timestamp prediction
+ if not self.use_cif1_cnn:
+ _output = context
+ else:
+ _output = output
+ if self.upsample_type == 'cnn':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ elif self.upsample_type == 'cnn_blstm':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ output2, (_, _) = self.blstm(output2)
+ elif self.upsample_type == 'cnn_attn':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ output2, _ = self.self_attn(output2, mask)
+ # import pdb; pdb.set_trace()
+ alphas2 = torch.sigmoid(self.cif_output2(output2))
+ alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+        # repeat the mask in the T dimension to match the upsampled length
+ if mask is not None:
+ mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+ mask2 = mask2.unsqueeze(-1)
+ alphas2 = alphas2 * mask2
+ alphas2 = alphas2.squeeze(-1)
+ token_num2 = alphas2.sum(-1)
+
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+ return acoustic_embeds, token_num, alphas, cif_peak, token_num2
+
+ def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
+ h = hidden
+ b = hidden.shape[0]
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+
+ # alphas2 is an extra head for timestamp prediction
+ if not self.use_cif1_cnn:
+ _output = context
+ else:
+ _output = output
+ if self.upsample_type == 'cnn':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ elif self.upsample_type == 'cnn_blstm':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ output2, (_, _) = self.blstm(output2)
+ elif self.upsample_type == 'cnn_attn':
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1,2)
+ output2, _ = self.self_attn(output2, mask)
+ alphas2 = torch.sigmoid(self.cif_output2(output2))
+ alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+        # repeat the mask in the T dimension to match the upsampled length
+ if mask is not None:
+ mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+ mask2 = mask2.unsqueeze(-1)
+ alphas2 = alphas2 * mask2
+ alphas2 = alphas2.squeeze(-1)
+ _token_num = alphas2.sum(-1)
+ if token_num is not None:
+ alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
+ # re-downsample
+ ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
+ ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
+ # upsampled alphas and cif_peak
+ us_alphas = alphas2
+ us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
+ return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+ b, t, d = hidden.size()
+ tail_threshold = self.tail_threshold
+ if mask is not None:
+ zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ ones_t = torch.ones_like(zeros_t)
+ mask_1 = torch.cat([mask, zeros_t], dim=1)
+ mask_2 = torch.cat([ones_t, mask], dim=1)
+ mask = mask_2 - mask_1
+ tail_threshold = mask * tail_threshold
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
+ else:
+ tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
+ tail_threshold = torch.reshape(tail_threshold, (1, 1))
+ alphas = torch.cat([alphas, tail_threshold], dim=1)
+ zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+ hidden = torch.cat([hidden, zeros], dim=1)
+ token_num = alphas.sum(dim=-1)
+ token_num_floor = torch.floor(token_num)
+
+ return hidden, alphas, token_num_floor
+
+ def gen_frame_alignments(self,
+ alphas: torch.Tensor = None,
+ encoder_sequence_length: torch.Tensor = None):
+ batch_size, maximum_length = alphas.size()
+ int_type = torch.int32
+
+ is_training = self.training
+ if is_training:
+ token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
+ else:
+ token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
+
+ max_token_num = torch.max(token_num).item()
+
+ alphas_cumsum = torch.cumsum(alphas, dim=1)
+ alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
+ alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
+
+ index = torch.ones([batch_size, max_token_num], dtype=int_type)
+ index = torch.cumsum(index, dim=1)
+ index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
+
+ index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
+ index_div_bool_zeros = index_div.eq(0)
+ index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
+ index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
+ token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
+ index_div_bool_zeros_count *= token_num_mask
+
+ index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
+ ones = torch.ones_like(index_div_bool_zeros_count_tile)
+ zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
+ ones = torch.cumsum(ones, dim=2)
+ cond = index_div_bool_zeros_count_tile == ones
+ index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
+
+ index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
+ index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
+ index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
+ index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
+ predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
+ int_type).to(encoder_sequence_length.device)
+ index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
+
+ predictor_alignments = index_div_bool_zeros_count_tile_out
+ predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
+ return predictor_alignments.detach(), predictor_alignments_length.detach()
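A small numeric sketch of the loop-based cif helper defined above: with a firing threshold of 1.0 and per-frame weights that sum to 3, three integrated embeddings come out (the toy values are chosen to be exact in binary floating point):

    import torch

    from funasr.models.bici_paraformer.cif_predictor import cif

    hidden = torch.randn(1, 6, 4)                                # (B, T, D) encoder states
    alphas = torch.tensor([[0.5, 0.5, 0.25, 0.75, 0.5, 0.5]])    # weights sum to 3.0
    frames, fires = cif(hidden, alphas, threshold=1.0)
    assert frames.shape == (1, 3, 4)   # one embedding per threshold crossing
    assert fires.shape == (1, 6)       # accumulated weight recorded at every frame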
diff --git a/funasr/models/bici_paraformer/model.py b/funasr/models/bici_paraformer/model.py
new file mode 100644
index 0000000..23a6985
--- /dev/null
+++ b/funasr/models/bici_paraformer/model.py
@@ -0,0 +1,328 @@
+
+import logging
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+
+from funasr.models.paraformer.search import Hypothesis
+
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import register_class, registry_tables
+from funasr.models.ctc.ctc import CTC
+
+from funasr.models.paraformer.model import Paraformer
+
+@register_class("model_classes", "BiCifParaformer")
+class BiCifParaformer(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+
+ def _calc_pre2_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
+
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
+
+ return loss_pre2
+
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
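+        # Note: pre_token_length2 (from the second CIF pass) is not returned here;
+        # it is only used for the auxiliary loss_pre2 during training (see _calc_pre2_loss).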
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+
+ def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
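+        # Illustrative note: us_alphas / us_peaks are the upsampled CIF weights and
+        # fired peaks used for timestamp prediction; downstream code slices them at
+        # roughly 3x the encoder frame rate, e.g. us_alphas[i][:encoder_out_lens[i] * 3].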
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ loss_pre2 = self._calc_pre2_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ else:
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
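+        # Worked example (illustrative, not part of the original patch): with
+        # ctc_weight=0.3 and predictor_weight=1.0 this reduces to
+        #   loss = 0.3 * loss_ctc + 0.7 * loss_att + loss_pre + 0.5 * loss_pre2,
+        # i.e. the second (upsampled) CIF pass is weighted at half the predictor weight.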
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+ stats["loss_pre2"] = loss_pre2.detach().cpu()
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + self.predictor_bias).sum())
+
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def generate(self,
+ data_in: list,
+ data_lengths: list = None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                               frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            return [], meta_data
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        # BiCifParaformer: second (upsampled) CIF pass for timestamp prediction
+
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
+ pre_token_length)
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+                # add sos/eos tokens to the greedy output sequence
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove sos/eos and blank symbol ids
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+ us_peaks[i][:encoder_out_lens[i] * 3],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0))
+
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
+
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+ "time_stamp_postprocessed": time_stamp_postprocessed,
+ "word_lists": word_lists
+ }
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+
+ return results, meta_data
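+
+
+# Minimal inference sketch (illustrative only; assumes a BiCifParaformer instance
+# `model` and matching `tokenizer`/`frontend` objects have already been built and
+# loaded elsewhere, e.g. by the inference entry point):
+#
+#   results, meta_data = model.generate(["example.wav"], key=["utt1"],
+#                                       tokenizer=tokenizer, frontend=frontend,
+#                                       device="cpu")
+#   print(results[0]["text_postprocessed"])
+#   print(results[0]["time_stamp_postprocessed"])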
diff --git a/funasr/models/branchformer/branchformer_encoder.py b/funasr/models/branchformer/encoder.py
similarity index 98%
rename from funasr/models/branchformer/branchformer_encoder.py
rename to funasr/models/branchformer/encoder.py
index 0f037c8..11b6429 100644
--- a/funasr/models/branchformer/branchformer_encoder.py
+++ b/funasr/models/branchformer/encoder.py
@@ -16,8 +16,8 @@
import numpy
import torch
+import torch.nn as nn
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.branchformer.cgmlp import ConvolutionalGatingMLP
from funasr.models.branchformer.fastformer import FastSelfAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
@@ -33,8 +33,8 @@
ScaledPositionalEncoding,
)
from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import (
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling2,
Conv2dSubsampling6,
@@ -43,6 +43,7 @@
check_short_utt,
)
+from funasr.utils.register import register_class
class BranchformerEncoderLayer(torch.nn.Module):
"""Branchformer encoder layer module.
@@ -290,8 +291,8 @@
return x, mask
-
-class BranchformerEncoder(AbsEncoder):
+@register_class("encoder_classes", "BranchformerEncoder")
+class BranchformerEncoder(nn.Module):
"""Branchformer encoder module."""
def __init__(
diff --git a/funasr/models/branchformer/model.py b/funasr/models/branchformer/model.py
index 5cb2af7..a14b407 100644
--- a/funasr/models/branchformer/model.py
+++ b/funasr/models/branchformer/model.py
@@ -1,57 +1,9 @@
import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class
+@register_class("model_classes", "Branchformer")
class Branchformer(Transformer):
"""CTC-attention hybrid Encoder-Decoder model"""
diff --git a/funasr/models/cnn/DTDNN.py b/funasr/models/cnn/DTDNN.py
index 3de0b1e..02fcfdf 100644
--- a/funasr/models/cnn/DTDNN.py
+++ b/funasr/models/cnn/DTDNN.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch import nn
-from funasr.modules.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+from funasr.models.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
BasicResBlock, get_nonlinear
diff --git a/funasr/models/cnn/ResNet.py b/funasr/models/cnn/ResNet.py
index 54c3901..b846e9e 100644
--- a/funasr/models/cnn/ResNet.py
+++ b/funasr/models/cnn/ResNet.py
@@ -16,7 +16,7 @@
import torch.nn.functional as F
import funasr.models.sond.pooling.pooling_layers as pooling_layers
-from funasr.modules.cnn.fusion import AFF
+from funasr.models.cnn.fusion import AFF
class ReLU(nn.Hardtanh):
diff --git a/funasr/models/cnn/ResNet_aug.py b/funasr/models/cnn/ResNet_aug.py
index 6b03c67..95416ef 100644
--- a/funasr/models/cnn/ResNet_aug.py
+++ b/funasr/models/cnn/ResNet_aug.py
@@ -16,7 +16,7 @@
import torch.nn.functional as F
import funasr.models.sond.pooling.pooling_layers as pooling_layers
-from funasr.modules.cnn.fusion import AFF
+from funasr.models.cnn.fusion import AFF
class ReLU(nn.Hardtanh):
diff --git a/funasr/models/conformer/conformer_encoder.py b/funasr/models/conformer/conformer_encoder.py
deleted file mode 100644
index 4b8cfe6..0000000
--- a/funasr/models/conformer/conformer_encoder.py
+++ /dev/null
@@ -1,1280 +0,0 @@
-# Copyright 2020 Tomoki Hayashi
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Conformer encoder definition."""
-
-import logging
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import torch
-from torch import nn
-
-from funasr.models.ctc import CTC
-from funasr.models.transformer.attention import (
- MultiHeadedAttention, # noqa: H301
- RelPositionMultiHeadedAttention, # noqa: H301
- RelPositionMultiHeadedAttentionChunk,
- LegacyRelPositionMultiHeadedAttention, # noqa: H301
-)
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.transformer.embedding import (
- PositionalEncoding, # noqa: H301
- ScaledPositionalEncoding, # noqa: H301
- RelPositionalEncoding, # noqa: H301
- LegacyRelPositionalEncoding, # noqa: H301
- StreamingRelPositionalEncoding,
-)
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.utils.nets_utils import get_activation
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.utils.nets_utils import (
- TooShortUttError,
- check_short_utt,
- make_chunk_mask,
- make_source_mask,
-)
-from funasr.models.transformer.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.models.transformer.repeat import repeat, MultiBlocks
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
-from funasr.models.transformer.subsampling import Conv2dSubsamplingPad
-from funasr.models.transformer.subsampling import StreamingConvInput
-
-class ConvolutionModule(nn.Module):
- """ConvolutionModule in Conformer model.
-
- Args:
- channels (int): The number of channels of conv layers.
- kernel_size (int): Kernerl size of conv layers.
-
- """
-
- def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
- """Construct an ConvolutionModule object."""
- super(ConvolutionModule, self).__init__()
- # kernerl_size should be a odd number for 'SAME' padding
- assert (kernel_size - 1) % 2 == 0
-
- self.pointwise_conv1 = nn.Conv1d(
- channels,
- 2 * channels,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=bias,
- )
- self.depthwise_conv = nn.Conv1d(
- channels,
- channels,
- kernel_size,
- stride=1,
- padding=(kernel_size - 1) // 2,
- groups=channels,
- bias=bias,
- )
- self.norm = nn.BatchNorm1d(channels)
- self.pointwise_conv2 = nn.Conv1d(
- channels,
- channels,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=bias,
- )
- self.activation = activation
-
- def forward(self, x):
- """Compute convolution module.
-
- Args:
- x (torch.Tensor): Input tensor (#batch, time, channels).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, channels).
-
- """
- # exchange the temporal dimension and the feature dimension
- x = x.transpose(1, 2)
-
- # GLU mechanism
- x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
- x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
-
- # 1D Depthwise Conv
- x = self.depthwise_conv(x)
- x = self.activation(self.norm(x))
-
- x = self.pointwise_conv2(x)
-
- return x.transpose(1, 2)
-
-
-class EncoderLayer(nn.Module):
- """Encoder layer module.
-
- Args:
- size (int): Input dimension.
- self_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
- can be used as the argument.
- feed_forward (torch.nn.Module): Feed-forward module instance.
- `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
- can be used as the argument.
- feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
- `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
- can be used as the argument.
- conv_module (torch.nn.Module): Convolution module instance.
- `ConvlutionModule` instance can be used as the argument.
- dropout_rate (float): Dropout rate.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied. i.e. x -> x + att(x)
- stochastic_depth_rate (float): Proability to skip this layer.
- During training, the layer may skip residual computation and return input
- as-is with given probability.
- """
-
- def __init__(
- self,
- size,
- self_attn,
- feed_forward,
- feed_forward_macaron,
- conv_module,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayer, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.feed_forward_macaron = feed_forward_macaron
- self.conv_module = conv_module
- self.norm_ff = LayerNorm(size) # for the FNN module
- self.norm_mha = LayerNorm(size) # for the MHA module
- if feed_forward_macaron is not None:
- self.norm_ff_macaron = LayerNorm(size)
- self.ff_scale = 0.5
- else:
- self.ff_scale = 1.0
- if self.conv_module is not None:
- self.norm_conv = LayerNorm(size) # for the CNN module
- self.norm_final = LayerNorm(size) # for the final output of the block
- self.dropout = nn.Dropout(dropout_rate)
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
-
- def forward(self, x_input, mask, cache=None):
- """Compute encoded features.
-
- Args:
- x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
- - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- - w/o pos emb: Tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
-
- """
- if isinstance(x_input, tuple):
- x, pos_emb = x_input[0], x_input[1]
- else:
- x, pos_emb = x_input, None
-
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- if pos_emb is not None:
- return (x, pos_emb), mask
- return x, mask
-
- # whether to use macaron style
- if self.feed_forward_macaron is not None:
- residual = x
- if self.normalize_before:
- x = self.norm_ff_macaron(x)
- x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
- self.feed_forward_macaron(x)
- )
- if not self.normalize_before:
- x = self.norm_ff_macaron(x)
-
- # multi-headed self-attention module
- residual = x
- if self.normalize_before:
- x = self.norm_mha(x)
-
- if cache is None:
- x_q = x
- else:
- assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
- x_q = x[:, -1:, :]
- residual = residual[:, -1:, :]
- mask = None if mask is None else mask[:, -1:, :]
-
- if pos_emb is not None:
- x_att = self.self_attn(x_q, x, x, pos_emb, mask)
- else:
- x_att = self.self_attn(x_q, x, x, mask)
-
- if self.concat_after:
- x_concat = torch.cat((x, x_att), dim=-1)
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = residual + stoch_layer_coeff * self.dropout(x_att)
- if not self.normalize_before:
- x = self.norm_mha(x)
-
- # convolution module
- if self.conv_module is not None:
- residual = x
- if self.normalize_before:
- x = self.norm_conv(x)
- x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
- if not self.normalize_before:
- x = self.norm_conv(x)
-
- # feed forward module
- residual = x
- if self.normalize_before:
- x = self.norm_ff(x)
- x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
- self.feed_forward(x)
- )
- if not self.normalize_before:
- x = self.norm_ff(x)
-
- if self.conv_module is not None:
- x = self.norm_final(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- if pos_emb is not None:
- return (x, pos_emb), mask
-
- return x, mask
-
-class ChunkEncoderLayer(torch.nn.Module):
- """Chunk Conformer module definition.
- Args:
- block_size: Input/output size.
- self_att: Self-attention module instance.
- feed_forward: Feed-forward module instance.
- feed_forward_macaron: Feed-forward module instance for macaron network.
- conv_mod: Convolution module instance.
- norm_class: Normalization module class.
- norm_args: Normalization module arguments.
- dropout_rate: Dropout rate.
- """
-
- def __init__(
- self,
- block_size: int,
- self_att: torch.nn.Module,
- feed_forward: torch.nn.Module,
- feed_forward_macaron: torch.nn.Module,
- conv_mod: torch.nn.Module,
- norm_class: torch.nn.Module = LayerNorm,
- norm_args: Dict = {},
- dropout_rate: float = 0.0,
- ) -> None:
- """Construct a Conformer object."""
- super().__init__()
-
- self.self_att = self_att
-
- self.feed_forward = feed_forward
- self.feed_forward_macaron = feed_forward_macaron
- self.feed_forward_scale = 0.5
-
- self.conv_mod = conv_mod
-
- self.norm_feed_forward = norm_class(block_size, **norm_args)
- self.norm_self_att = norm_class(block_size, **norm_args)
-
- self.norm_macaron = norm_class(block_size, **norm_args)
- self.norm_conv = norm_class(block_size, **norm_args)
- self.norm_final = norm_class(block_size, **norm_args)
-
- self.dropout = torch.nn.Dropout(dropout_rate)
-
- self.block_size = block_size
- self.cache = None
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset self-attention and convolution modules cache for streaming.
- Args:
- left_context: Number of left frames during chunk-by-chunk inference.
- device: Device to use for cache tensor.
- """
- self.cache = [
- torch.zeros(
- (1, left_context, self.block_size),
- device=device,
- ),
- torch.zeros(
- (
- 1,
- self.block_size,
- self.conv_mod.kernel_size - 1,
- ),
- device=device,
- ),
- ]
-
- def forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Encode input sequences.
- Args:
- x: Conformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T)
- chunk_mask: Chunk mask. (T_2, T_2)
- Returns:
- x: Conformer output sequences. (B, T, D_block)
- mask: Source mask. (B, T)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- """
- residual = x
-
- x = self.norm_macaron(x)
- x = residual + self.feed_forward_scale * self.dropout(
- self.feed_forward_macaron(x)
- )
-
- residual = x
- x = self.norm_self_att(x)
- x_q = x
- x = residual + self.dropout(
- self.self_att(
- x_q,
- x,
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
- )
-
- residual = x
-
- x = self.norm_conv(x)
- x, _ = self.conv_mod(x)
- x = residual + self.dropout(x)
- residual = x
-
- x = self.norm_feed_forward(x)
- x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
-
- x = self.norm_final(x)
- return x, mask, pos_enc
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_size: int = 16,
- left_context: int = 0,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode chunk of input sequence.
- Args:
- x: Conformer input sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- mask: Source mask. (B, T_2)
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
- Returns:
- x: Conformer output sequences. (B, T, D_block)
- pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
- """
- residual = x
-
- x = self.norm_macaron(x)
- x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
-
- residual = x
- x = self.norm_self_att(x)
- if left_context > 0:
- key = torch.cat([self.cache[0], x], dim=1)
- else:
- key = x
- val = key
-
- if right_context > 0:
- att_cache = key[:, -(left_context + right_context) : -right_context, :]
- else:
- att_cache = key[:, -left_context:, :]
- x = residual + self.self_att(
- x,
- key,
- val,
- pos_enc,
- mask,
- left_context=left_context,
- )
-
- residual = x
- x = self.norm_conv(x)
- x, conv_cache = self.conv_mod(
- x, cache=self.cache[1], right_context=right_context
- )
- x = residual + x
- residual = x
-
- x = self.norm_feed_forward(x)
- x = residual + self.feed_forward_scale * self.feed_forward(x)
-
- x = self.norm_final(x)
- self.cache = [att_cache, conv_cache]
-
- return x, pos_enc
-
-
-class ConformerEncoder(AbsEncoder):
- """Conformer encoder module.
-
- Args:
- input_size (int): Input dimension.
- output_size (int): Dimension of attention.
- attention_heads (int): The number of heads of multi head attention.
- linear_units (int): The number of units of position-wise feed forward.
- num_blocks (int): The number of decoder blocks.
- dropout_rate (float): Dropout rate.
- attention_dropout_rate (float): Dropout rate in attention.
- positional_dropout_rate (float): Dropout rate after adding positional encoding.
- input_layer (Union[str, torch.nn.Module]): Input layer type.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- If True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- If False, no additional linear will be applied. i.e. x -> x + att(x)
- positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
- positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
- rel_pos_type (str): Whether to use the latest relative positional encoding or
- the legacy one. The legacy relative positional encoding will be deprecated
- in the future. More Details can be found in
- https://github.com/espnet/espnet/pull/2816.
- encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
- encoder_attn_layer_type (str): Encoder attention layer type.
- activation_type (str): Encoder activation function type.
- macaron_style (bool): Whether to use macaron style for positionwise layer.
- use_cnn_module (bool): Whether to use convolution module.
- zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
- cnn_module_kernel (int): Kernerl size of convolution module.
- padding_idx (int): Padding idx for input_layer=embed.
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- stochastic_depth_rate: Union[float, List[float]] = 0.0,
- ):
- super().__init__()
- self._output_size = output_size
-
- if rel_pos_type == "legacy":
- if pos_enc_layer_type == "rel_pos":
- pos_enc_layer_type = "legacy_rel_pos"
- if selfattention_layer_type == "rel_selfattn":
- selfattention_layer_type = "legacy_rel_selfattn"
- elif rel_pos_type == "latest":
- assert selfattention_layer_type != "legacy_rel_selfattn"
- assert pos_enc_layer_type != "legacy_rel_pos"
- else:
- raise ValueError("unknown rel_pos_type: " + rel_pos_type)
-
- activation = get_activation(activation_type)
- if pos_enc_layer_type == "abs_pos":
- pos_enc_class = PositionalEncoding
- elif pos_enc_layer_type == "scaled_abs_pos":
- pos_enc_class = ScaledPositionalEncoding
- elif pos_enc_layer_type == "rel_pos":
- assert selfattention_layer_type == "rel_selfattn"
- pos_enc_class = RelPositionalEncoding
- elif pos_enc_layer_type == "legacy_rel_pos":
- assert selfattention_layer_type == "legacy_rel_selfattn"
- pos_enc_class = LegacyRelPositionalEncoding
- logging.warning(
- "Using legacy_rel_pos and it will be deprecated in the future."
- )
- else:
- raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(
- input_size,
- output_size,
- dropout_rate,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2dpad":
- self.embed = Conv2dSubsamplingPad(
- input_size,
- output_size,
- dropout_rate,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(
- input_size,
- output_size,
- dropout_rate,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(
- input_size,
- output_size,
- dropout_rate,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(
- input_size,
- output_size,
- dropout_rate,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif isinstance(input_layer, torch.nn.Module):
- self.embed = torch.nn.Sequential(
- input_layer,
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer is None:
- self.embed = torch.nn.Sequential(
- pos_enc_class(output_size, positional_dropout_rate)
- )
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- activation,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
- elif selfattention_layer_type == "legacy_rel_selfattn":
- assert pos_enc_layer_type == "legacy_rel_pos"
- encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
- logging.warning(
- "Using legacy_rel_selfattn and it will be deprecated in the future."
- )
- elif selfattention_layer_type == "rel_selfattn":
- assert pos_enc_layer_type == "rel_pos"
- encoder_selfattn_layer = RelPositionMultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- zero_triu,
- )
- else:
- raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
-
- convolution_layer = ConvolutionModule
- convolution_layer_args = (output_size, cnn_module_kernel, activation)
-
- if isinstance(stochastic_depth_rate, float):
- stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
-
- if len(stochastic_depth_rate) != num_blocks:
- raise ValueError(
- f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
- f"should be equal to num_blocks ({num_blocks})"
- )
-
- self.encoders = repeat(
- num_blocks,
- lambda lnum: EncoderLayer(
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- positionwise_layer(*positionwise_layer_args) if macaron_style else None,
- convolution_layer(*convolution_layer_args) if use_cnn_module else None,
- dropout_rate,
- normalize_before,
- concat_after,
- stochastic_depth_rate[lnum],
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Calculate forward propagation.
-
- Args:
- xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
- ilens (torch.Tensor): Input length (#batch).
- prev_states (torch.Tensor): Not to be used now.
-
- Returns:
- torch.Tensor: Output tensor (#batch, L, output_size).
- torch.Tensor: Output length (#batch).
- torch.Tensor: Not to be used now.
-
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-
- if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- or isinstance(self.embed, Conv2dSubsamplingPad)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- xs_pad, masks = self.encoders(xs_pad, masks)
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- xs_pad, masks = encoder_layer(xs_pad, masks)
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
-
- if isinstance(xs_pad, tuple):
- x, pos_emb = xs_pad
- x = x + self.conditioning_layer(ctc_out)
- xs_pad = (x, pos_emb)
- else:
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if isinstance(xs_pad, tuple):
- xs_pad = xs_pad[0]
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
-
-class CausalConvolution(torch.nn.Module):
- """ConformerConvolution module definition.
- Args:
- channels: The number of channels.
- kernel_size: Size of the convolving kernel.
- activation: Type of activation function.
- norm_args: Normalization module arguments.
- causal: Whether to use causal convolution (set to True if streaming).
- """
-
- def __init__(
- self,
- channels: int,
- kernel_size: int,
- activation: torch.nn.Module = torch.nn.ReLU(),
- norm_args: Dict = {},
- causal: bool = False,
- ) -> None:
- """Construct an ConformerConvolution object."""
- super().__init__()
-
- assert (kernel_size - 1) % 2 == 0
-
- self.kernel_size = kernel_size
-
- self.pointwise_conv1 = torch.nn.Conv1d(
- channels,
- 2 * channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
-
- if causal:
- self.lorder = kernel_size - 1
- padding = 0
- else:
- self.lorder = 0
- padding = (kernel_size - 1) // 2
-
- self.depthwise_conv = torch.nn.Conv1d(
- channels,
- channels,
- kernel_size,
- stride=1,
- padding=padding,
- groups=channels,
- )
- self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
- self.pointwise_conv2 = torch.nn.Conv1d(
- channels,
- channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
-
- self.activation = activation
-
- def forward(
- self,
- x: torch.Tensor,
- cache: Optional[torch.Tensor] = None,
- right_context: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute convolution module.
- Args:
- x: ConformerConvolution input sequences. (B, T, D_hidden)
- cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
- right_context: Number of frames in right context.
- Returns:
- x: ConformerConvolution output sequences. (B, T, D_hidden)
- cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
- """
- x = self.pointwise_conv1(x.transpose(1, 2))
- x = torch.nn.functional.glu(x, dim=1)
-
- if self.lorder > 0:
- if cache is None:
- x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
- else:
- x = torch.cat([cache, x], dim=2)
-
- if right_context > 0:
- cache = x[:, :, -(self.lorder + right_context) : -right_context]
- else:
- cache = x[:, :, -self.lorder :]
-
- x = self.depthwise_conv(x)
- x = self.activation(self.norm(x))
-
- x = self.pointwise_conv2(x).transpose(1, 2)
-
- return x, cache
-
-class ConformerChunkEncoder(AbsEncoder):
- """Encoder module definition.
- Args:
- input_size: Input size.
- body_conf: Encoder body configuration.
- input_conf: Encoder input configuration.
- main_conf: Encoder main configuration.
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- embed_vgg_like: bool = False,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- norm_type: str = "layer_norm",
- cnn_module_kernel: int = 31,
- conv_mod_norm_eps: float = 0.00001,
- conv_mod_norm_momentum: float = 0.1,
- simplified_att_score: bool = False,
- dynamic_chunk_training: bool = False,
- short_chunk_threshold: float = 0.75,
- short_chunk_size: int = 25,
- left_chunk_size: int = 0,
- time_reduction_factor: int = 1,
- unified_model_training: bool = False,
- default_chunk_size: int = 16,
- jitter_range: int = 4,
- subsampling_factor: int = 1,
- ) -> None:
- """Construct an Encoder object."""
- super().__init__()
-
-
- self.embed = StreamingConvInput(
- input_size,
- output_size,
- subsampling_factor,
- vgg_like=embed_vgg_like,
- output_size=output_size,
- )
-
- self.pos_enc = StreamingRelPositionalEncoding(
- output_size,
- positional_dropout_rate,
- )
-
- activation = get_activation(
- activation_type
- )
-
- pos_wise_args = (
- output_size,
- linear_units,
- positional_dropout_rate,
- activation,
- )
-
- conv_mod_norm_args = {
- "eps": conv_mod_norm_eps,
- "momentum": conv_mod_norm_momentum,
- }
-
- conv_mod_args = (
- output_size,
- cnn_module_kernel,
- activation,
- conv_mod_norm_args,
- dynamic_chunk_training or unified_model_training,
- )
-
- mult_att_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- simplified_att_score,
- )
-
-
- fn_modules = []
- for _ in range(num_blocks):
- module = lambda: ChunkEncoderLayer(
- output_size,
- RelPositionMultiHeadedAttentionChunk(*mult_att_args),
- PositionwiseFeedForward(*pos_wise_args),
- PositionwiseFeedForward(*pos_wise_args),
- CausalConvolution(*conv_mod_args),
- dropout_rate=dropout_rate,
- )
- fn_modules.append(module)
-
- self.encoders = MultiBlocks(
- [fn() for fn in fn_modules],
- output_size,
- )
-
- self._output_size = output_size
-
- self.dynamic_chunk_training = dynamic_chunk_training
- self.short_chunk_threshold = short_chunk_threshold
- self.short_chunk_size = short_chunk_size
- self.left_chunk_size = left_chunk_size
-
- self.unified_model_training = unified_model_training
- self.default_chunk_size = default_chunk_size
- self.jitter_range = jitter_range
-
- self.time_reduction_factor = time_reduction_factor
-
- def output_size(self) -> int:
- return self._output_size
-
- def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
- """Return the corresponding number of sample for a given chunk size, in frames.
- Where size is the number of features frames after applying subsampling.
- Args:
- size: Number of frames after subsampling.
- hop_length: Frontend's hop length
- Returns:
- : Number of raw samples
- """
- return self.embed.get_size_before_subsampling(size) * hop_length
-
- def get_encoder_input_size(self, size: int) -> int:
- """Return the corresponding number of sample for a given chunk size, in frames.
- Where size is the number of features frames after applying subsampling.
- Args:
- size: Number of frames after subsampling.
- Returns:
- : Number of raw samples
- """
- return self.embed.get_size_before_subsampling(size)
-
-
- def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
- """Initialize/Reset encoder streaming cache.
- Args:
- left_context: Number of frames in left context.
- device: Device ID.
- """
- return self.encoders.reset_streaming_cache(left_context, device)
-
- def forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode input sequences.
- Args:
- x: Encoder input features. (B, T_in, F)
- x_len: Encoder input features lengths. (B,)
- Returns:
- x: Encoder outputs. (B, T_out, D_enc)
- x_len: Encoder outputs lenghts. (B,)
- """
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
-
- if short_status:
- raise TooShortUttError(
- f"has {x.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- x.size(1),
- limit_size,
- )
-
- mask = make_source_mask(x_len).to(x.device)
-
- if self.unified_model_training:
- if self.training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
- else:
- chunk_size = self.default_chunk_size
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
- x_utt = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=None,
- )
- x_chunk = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
-
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x_utt = x_utt[:,::self.time_reduction_factor,:]
- x_chunk = x_chunk[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
- return x_utt, x_chunk, olens
-
- elif self.dynamic_chunk_training:
- max_len = x.size(1)
- if self.training:
- chunk_size = torch.randint(1, max_len, (1,)).item()
-
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
- else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
- else:
- chunk_size = self.default_chunk_size
-
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
-
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
- else:
- x, mask = self.embed(x, mask, None)
- pos_enc = self.pos_enc(x)
- chunk_mask = None
- x = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
-
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
- return x, olens, None
-
- def full_utt_forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Encode input sequences.
- Args:
- x: Encoder input features. (B, T_in, F)
- x_len: Encoder input features lengths. (B,)
- Returns:
- x: Encoder outputs. (B, T_out, D_enc)
- x_len: Encoder outputs lenghts. (B,)
- """
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
-
- if short_status:
- raise TooShortUttError(
- f"has {x.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- x.size(1),
- limit_size,
- )
-
- mask = make_source_mask(x_len).to(x.device)
- x, mask = self.embed(x, mask, None)
- pos_enc = self.pos_enc(x)
- x_utt = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=None,
- )
-
- if self.time_reduction_factor > 1:
- x_utt = x_utt[:,::self.time_reduction_factor,:]
- return x_utt
-
- def simu_chunk_forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- ) -> torch.Tensor:
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
-
- if short_status:
- raise TooShortUttError(
- f"has {x.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- x.size(1),
- limit_size,
- )
-
- mask = make_source_mask(x_len)
-
- x, mask = self.embed(x, mask, chunk_size)
- pos_enc = self.pos_enc(x)
- chunk_mask = make_chunk_mask(
- x.size(1),
- chunk_size,
- left_chunk_size=self.left_chunk_size,
- device=x.device,
- )
-
- x = self.encoders(
- x,
- pos_enc,
- mask,
- chunk_mask=chunk_mask,
- )
- olens = mask.eq(0).sum(1)
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
-
- return x
-
- def chunk_forward(
- self,
- x: torch.Tensor,
- x_len: torch.Tensor,
- processed_frames: torch.tensor,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- ) -> torch.Tensor:
- """Encode input sequences as chunks.
- Args:
- x: Encoder input features. (1, T_in, F)
- x_len: Encoder input features lengths. (1,)
- processed_frames: Number of frames already seen.
- left_context: Number of frames in left context.
- right_context: Number of frames in right context.
- Returns:
- x: Encoder outputs. (B, T_out, D_enc)
- """
- mask = make_source_mask(x_len)
- x, mask = self.embed(x, mask, None)
-
- if left_context > 0:
- processed_mask = (
- torch.arange(left_context, device=x.device)
- .view(1, left_context)
- .flip(1)
- )
- processed_mask = processed_mask >= processed_frames
- mask = torch.cat([processed_mask, mask], dim=1)
- pos_enc = self.pos_enc(x, left_context=left_context)
- x = self.encoders.chunk_forward(
- x,
- pos_enc,
- mask,
- chunk_size=chunk_size,
- left_context=left_context,
- right_context=right_context,
- )
-
- if right_context > 0:
- x = x[:, 0:-right_context, :]
-
- if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
- return x
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
new file mode 100644
index 0000000..709e10e
--- /dev/null
+++ b/funasr/models/conformer/encoder.py
@@ -0,0 +1,613 @@
+# Copyright 2020 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Conformer encoder definition."""
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import torch
+from torch import nn
+
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.attention import (
+ MultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttention, # noqa: H301
+ LegacyRelPositionMultiHeadedAttention, # noqa: H301
+)
+from funasr.models.transformer.embedding import (
+ PositionalEncoding, # noqa: H301
+ ScaledPositionalEncoding, # noqa: H301
+ RelPositionalEncoding, # noqa: H301
+ LegacyRelPositionalEncoding, # noqa: H301
+ StreamingRelPositionalEncoding,
+)
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.utils.nets_utils import get_activation
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.utils.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
+from funasr.models.transformer.utils.subsampling import StreamingConvInput
+from funasr.utils.register import register_class
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+
+ """
+
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(self, x):
+ """Compute convolution module.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x)
+
+ return x.transpose(1, 2)
+
+
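+# Shape sketch for ConvolutionModule (illustrative only): an input of shape
+# (batch=4, time=100, channels=256) is transposed to (4, 256, 100), becomes
+# (4, 512, 100) after pointwise_conv1, (4, 256, 100) after the GLU, keeps that shape
+# through the depthwise and second pointwise convolutions, and the final transpose
+# returns (4, 100, 256).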
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+        stochastic_depth_rate (float): Probability to skip this layer.
+ During training, the layer may skip residual computation and return input
+ as-is with given probability.
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ feed_forward_macaron,
+ conv_module,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = LayerNorm(size) # for the FNN module
+ self.norm_mha = LayerNorm(size) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = LayerNorm(size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = LayerNorm(size) # for the CNN module
+ self.norm_final = LayerNorm(size) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+
+ def forward(self, x_input, mask, cache=None):
+ """Compute encoded features.
+
+ Args:
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+ - w/o pos emb: Tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+ if isinstance(x_input, tuple):
+ x, pos_emb = x_input[0], x_input[1]
+ else:
+ x, pos_emb = x_input, None
+
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
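+            # e.g. (illustrative) with stochastic_depth_rate=0.2 this layer is skipped
+            # with probability 0.2 and, when kept, its residual branches are scaled by
+            # 1 / (1 - 0.2) = 1.25 so the expected output matches inference behaviour.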
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+ return x, mask
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
+ self.feed_forward_macaron(x)
+ )
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = residual + stoch_layer_coeff * self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+ x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
+ self.feed_forward(x)
+ )
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
+
+
+@register_class("encoder_classes", "ConformerEncoder")
+class ConformerEncoder(nn.Module):
+ """Conformer encoder module.
+
+ Args:
+ input_size (int): Input dimension.
+ output_size (int): Dimension of attention.
+ attention_heads (int): The number of heads of multi head attention.
+ linear_units (int): The number of units of position-wise feed forward.
+ num_blocks (int): The number of decoder blocks.
+ dropout_rate (float): Dropout rate.
+ attention_dropout_rate (float): Dropout rate in attention.
+ positional_dropout_rate (float): Dropout rate after adding positional encoding.
+ input_layer (Union[str, torch.nn.Module]): Input layer type.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ If True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ If False, no additional linear will be applied. i.e. x -> x + att(x)
+ positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
+ positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
+ rel_pos_type (str): Whether to use the latest relative positional encoding or
+ the legacy one. The legacy relative positional encoding will be deprecated
+ in the future. More Details can be found in
+ https://github.com/espnet/espnet/pull/2816.
+        pos_enc_layer_type (str): Encoder positional encoding layer type.
+        selfattention_layer_type (str): Encoder attention layer type.
+ activation_type (str): Encoder activation function type.
+ macaron_style (bool): Whether to use macaron style for positionwise layer.
+ use_cnn_module (bool): Whether to use convolution module.
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+        cnn_module_kernel (int): Kernel size of convolution module.
+ padding_idx (int): Padding idx for input_layer=embed.
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ stochastic_depth_rate: Union[float, List[float]] = 0.0,
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ if rel_pos_type == "legacy":
+ if pos_enc_layer_type == "rel_pos":
+ pos_enc_layer_type = "legacy_rel_pos"
+ if selfattention_layer_type == "rel_selfattn":
+ selfattention_layer_type = "legacy_rel_selfattn"
+ elif rel_pos_type == "latest":
+ assert selfattention_layer_type != "legacy_rel_selfattn"
+ assert pos_enc_layer_type != "legacy_rel_pos"
+ else:
+ raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+ activation = get_activation(activation_type)
+ if pos_enc_layer_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_layer_type == "scaled_abs_pos":
+ pos_enc_class = ScaledPositionalEncoding
+ elif pos_enc_layer_type == "rel_pos":
+ assert selfattention_layer_type == "rel_selfattn"
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_layer_type == "legacy_rel_pos":
+ assert selfattention_layer_type == "legacy_rel_selfattn"
+ pos_enc_class = LegacyRelPositionalEncoding
+ logging.warning(
+ "Using legacy_rel_pos and it will be deprecated in the future."
+ )
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2dpad":
+ self.embed = Conv2dSubsamplingPad(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif isinstance(input_layer, torch.nn.Module):
+ self.embed = torch.nn.Sequential(
+ input_layer,
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer is None:
+ self.embed = torch.nn.Sequential(
+ pos_enc_class(output_size, positional_dropout_rate)
+ )
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ elif selfattention_layer_type == "legacy_rel_selfattn":
+ assert pos_enc_layer_type == "legacy_rel_pos"
+ encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ logging.warning(
+ "Using legacy_rel_selfattn and it will be deprecated in the future."
+ )
+ elif selfattention_layer_type == "rel_selfattn":
+ assert pos_enc_layer_type == "rel_pos"
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ zero_triu,
+ )
+ else:
+ raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
+
+ convolution_layer = ConvolutionModule
+ convolution_layer_args = (output_size, cnn_module_kernel, activation)
+
+ if isinstance(stochastic_depth_rate, float):
+ stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
+
+ if len(stochastic_depth_rate) != num_blocks:
+ raise ValueError(
+ f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
+ f"should be equal to num_blocks ({num_blocks})"
+ )
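+        # illustrative example: stochastic_depth_rate=0.1 with num_blocks=6 expands to
+        # [0.1] * 6; a per-layer list (e.g. growing with depth) may also be passed directly.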
+
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EncoderLayer(
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ stochastic_depth_rate[lnum],
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+ ilens (torch.Tensor): Input length (#batch).
+ prev_states (torch.Tensor): Not to be used now.
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, L, output_size).
+ torch.Tensor: Output length (#batch).
+ torch.Tensor: Not to be used now.
+
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+        if isinstance(
+            self.embed,
+            (
+                Conv2dSubsampling,
+                Conv2dSubsampling2,
+                Conv2dSubsampling6,
+                Conv2dSubsampling8,
+                Conv2dSubsamplingPad,
+            ),
+        ):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ xs_pad, masks = self.encoders(xs_pad, masks)
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ xs_pad, masks = encoder_layer(xs_pad, masks)
+
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+
+ if isinstance(xs_pad, tuple):
+ x, pos_emb = xs_pad
+ x = x + self.conditioning_layer(ctc_out)
+ xs_pad = (x, pos_emb)
+ else:
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if isinstance(xs_pad, tuple):
+ xs_pad = xs_pad[0]
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
+
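+
+# Minimal usage sketch (illustrative only; the argument values below are assumptions,
+# not defaults taken from any recipe):
+#
+#     encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=6)
+#     feats = torch.randn(2, 100, 80)        # (#batch, time, input_size)
+#     feats_lens = torch.tensor([100, 80])   # valid frame counts per utterance
+#     out, out_lens, _ = encoder(feats, feats_lens)
+#     # out: (#batch, subsampled_time, 256); out_lens: lengths after conv2d subsampling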
diff --git a/funasr/models/conformer/model.py b/funasr/models/conformer/model.py
index 48f04e4..5319a73 100644
--- a/funasr/models/conformer/model.py
+++ b/funasr/models/conformer/model.py
@@ -1,57 +1,11 @@
import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+
import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class, registry_tables
+@register_class("model_classes", "Conformer")
class Conformer(Transformer):
"""CTC-attention hybrid Encoder-Decoder model"""
diff --git a/funasr/models/ct_transformer/attention.py b/funasr/models/ct_transformer/attention.py
new file mode 100644
index 0000000..a35ddee
--- /dev/null
+++ b/funasr/models/ct_transformer/attention.py
@@ -0,0 +1,1091 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
+
+
+class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding (old version).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ Paper: https://arxiv.org/abs/1901.02860
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ self.zero_triu = zero_triu
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x):
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if self.zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time1)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ Paper: https://arxiv.org/abs/1901.02860
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ self.zero_triu = zero_triu
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x):
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
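+        # illustrative example: with time1 = 3 the relative axis holds 2 * 3 - 1 = 5 positions;
+        # after the shift only the first 5 // 2 + 1 = 3 columns are kept, i.e. the offsets that
+        # correspond to the time2 = time1 keys when no left context is attended.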
+
+ if self.zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, 2*time1-1, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, 2*time1-1)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
+
+
+class MultiHeadedAttentionSANM(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttentionSANM, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ # self.linear_q = nn.Linear(n_feat, n_feat)
+ # self.linear_k = nn.Linear(n_feat, n_feat)
+ # self.linear_v = nn.Linear(n_feat, n_feat)
+ if lora_list is not None:
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+ if lora_qkv_list == [False, False, False]:
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+ # padding
+ left_padding = (kernel_size - 1) // 2
+ if sanm_shfit > 0:
+ left_padding = left_padding + sanm_shfit
+ right_padding = kernel_size - 1 - left_padding
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
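+        # illustrative example: kernel_size=11 with sanm_shfit=0 yields symmetric padding (5, 5),
+        # i.e. 5 past and 5 future frames feed each FSMN step; sanm_shfit=5 would yield (10, 0),
+        # making the memory block fully causal.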
+
+ def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+ b, t, d = inputs.size()
+ if mask is not None:
+ mask = torch.reshape(mask, (b, -1, 1))
+ if mask_shfit_chunk is not None:
+ mask = mask * mask_shfit_chunk
+ inputs = inputs * mask
+
+ x = inputs.transpose(1, 2)
+ x = self.pad_fn(x)
+ x = self.fsmn_block(x)
+ x = x.transpose(1, 2)
+ x += inputs
+ x = self.dropout(x)
+ if mask is not None:
+ x = x * mask
+ return x
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs + fsmn_memory
+
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+        if chunk_size is not None and (look_back > 0 or look_back == -1):
+ if cache is not None:
+ k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+ v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+
+ cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+ cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+ if look_back != -1:
+ cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+ "v": v_h[:, :, :-(chunk_size[2]), :]}
+ cache = cache_tmp
+ fsmn_memory = self.forward_fsmn(v, None)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, None)
+ return att_outs + fsmn_memory, cache
+
+
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+ return att_outs + fsmn_memory
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttentionSANMDecoder, self).__init__()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ self.fsmn_block = nn.Conv1d(n_feat, n_feat,
+ kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+ # padding
+ left_padding = (kernel_size - 1) // 2
+ if sanm_shfit > 0:
+ left_padding = left_padding + sanm_shfit
+ right_padding = kernel_size - 1 - left_padding
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+ self.kernel_size = kernel_size
+
+ def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
+        '''
+        :param inputs: Input tensor (#batch, time1, size).
+        :param mask: Mask tensor (#batch, 1, time).
+        :param cache: Cached padded input from the previous chunk (streaming only).
+        :param mask_shfit_chunk: Optional chunk mask applied on top of mask.
+        :return: Tuple of output tensor (#batch, time1, size) and updated cache.
+        '''
+ # print("in fsmn, inputs", inputs.size())
+ b, t, d = inputs.size()
+ # logging.info(
+ # "mask: {}".format(mask.size()))
+ if mask is not None:
+ mask = torch.reshape(mask, (b ,-1, 1))
+ # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+ if mask_shfit_chunk is not None:
+ # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
+ mask = mask * mask_shfit_chunk
+ # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+ # print("in fsmn, mask", mask.size())
+ # print("in fsmn, inputs", inputs.size())
+ inputs = inputs * mask
+
+ x = inputs.transpose(1, 2)
+ b, d, t = x.size()
+        if cache is None:
+            x = self.pad_fn(x)
+            if not self.training:
+                cache = x
+        else:
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size + t - 1):]
+            cache = x
+ x = self.fsmn_block(x)
+ x = x.transpose(1, 2)
+ # print("in fsmn, fsmn_out", x.size())
+ if x.size(1) != inputs.size(1):
+ inputs = inputs[:, -1, :]
+
+ x = x + inputs
+ x = self.dropout(x)
+ if mask is not None:
+ x = x * mask
+ return x, cache
+
+class MultiHeadedAttentionCrossAtt(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttentionCrossAtt, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ if lora_list is not None:
+ if "q" in lora_list:
+ self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ lora_kv_list = ["k" in lora_list, "v" in lora_list]
+ if lora_kv_list == [False, False]:
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x, memory):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+
+ # print("in forward_qkv, x", x.size())
+ b = x.size(0)
+ q = self.linear_q(x)
+ q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+
+ k_v = self.linear_k_v(memory)
+ k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
+ k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, memory, memory_mask):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h = self.forward_qkv(x, memory)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ return self.forward_attention(v_h, scores, memory_mask)
+
+ def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h = self.forward_qkv(x, memory)
+ if chunk_size is not None and look_back > 0:
+ if cache is not None:
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+ cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+ "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+ cache = cache_tmp
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ return self.forward_attention(v_h, scores, None), cache
+
+
+class MultiHeadSelfAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadSelfAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+ """RelPositionMultiHeadedAttention definition.
+ Args:
+ num_heads: Number of attention heads.
+ embed_size: Embedding size.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ num_heads: int,
+ embed_size: int,
+ dropout_rate: float = 0.0,
+ simplified_attention_score: bool = False,
+ ) -> None:
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+
+ self.d_k = embed_size // num_heads
+ self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
+
+ self.linear_q = torch.nn.Linear(embed_size, embed_size)
+ self.linear_k = torch.nn.Linear(embed_size, embed_size)
+ self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+ self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+ if simplified_attention_score:
+ self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+ self.compute_att_score = self.compute_simplified_attention_score
+ else:
+ self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+ self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ self.compute_att_score = self.compute_attention_score
+
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.attn = None
+
+ def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+ """Compute relative positional encoding.
+ Args:
+ x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+ left_context: Number of frames in left context.
+ Returns:
+ x: Output sequence. (B, H, T_1, T_2)
+ """
+ batch_size, n_heads, time1, n = x.shape
+ time2 = time1 + left_context
+
+ batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
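+        # illustrative note: rather than padding and reshaping, the shifted (B, H, T_1, T_2) view
+        # is built with as_strided by advancing the time1 axis one element less per row, which
+        # realizes the relative shift without copying memory.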
+
+ return x.as_strided(
+ (batch_size, n_heads, time1, time2),
+ (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+ storage_offset=(n_stride * (time1 - 1)),
+ )
+
+ def compute_simplified_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Simplified attention score computation.
+ Reference: https://github.com/k2-fsa/icefall/pull/458
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ pos_enc = self.linear_pos(pos_enc)
+
+ matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+ matrix_bd = self.rel_shift(
+ pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+ left_context=left_context,
+ )
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def compute_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Attention score computation.
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+ matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+ Returns:
+ q: Transformed query tensor. (B, H, T_1, d_k)
+ k: Transformed key tensor. (B, H, T_2, d_k)
+ v: Transformed value tensor. (B, H, T_2, d_k)
+ """
+ n_batch = query.size(0)
+
+ q = (
+ self.linear_q(query)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ k = (
+ self.linear_k(key)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ v = (
+ self.linear_v(value)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ value: Transformed value. (B, H, T_2, d_k)
+ scores: Attention score. (B, H, T_1, T_2)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ Returns:
+ attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+ """
+ batch_size = scores.size(0)
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ if chunk_mask is not None:
+ mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+ scores = scores.masked_fill(mask, float("-inf"))
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+ attn_output = self.dropout(self.attn)
+ attn_output = torch.matmul(attn_output, value)
+
+ attn_output = self.linear_out(
+ attn_output.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.num_heads * self.d_k)
+ )
+
+ return attn_output
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Compute scaled dot product attention with rel. positional encoding.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+ value: Value tensor. (B, T_2, size)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ left_context: Number of frames in left context.
+ Returns:
+ : Output tensor. (B, T_1, H * d_k)
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+ return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+ """ Compute Cosine Distance between spk decoder output and speaker profile
+ Args:
+ profile_path: speaker profile file path (.npy file)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, spk_decoder_out, profile, profile_lens=None):
+ """
+ Args:
+ spk_decoder_out(torch.Tensor):(B, L, D)
+ spk_profiles(torch.Tensor):(B, N, D)
+ """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
+        if profile_lens is not None:
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            x = x[:, -1:, :, :]
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
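+
+# Minimal usage sketch for CosineDistanceAttention (illustrative only; shapes are assumptions):
+#
+#     att = CosineDistanceAttention()
+#     dec_out = torch.randn(2, 10, 256)       # (B, L, D) speaker decoder outputs
+#     profiles = torch.randn(2, 4, 256)       # (B, N, D) enrolled speaker profiles
+#     profile_lens = torch.tensor([4, 3])     # valid profiles per utterance
+#     spk_emb, weights = att(dec_out, profiles, profile_lens)
+#     # spk_emb: (B, L, D) profile mixture per step; weights: (B, L, N) attention over profiles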
diff --git a/funasr/models/ct_transformer/sanm_encoder.py b/funasr/models/ct_transformer/sanm_encoder.py
new file mode 100644
index 0000000..1bdf5d5
--- /dev/null
+++ b/funasr/models/ct_transformer/sanm_encoder.py
@@ -0,0 +1,383 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.scama.chunk_utilis import overlap_chunk
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention
+from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class
+
+class EncoderLayerSANM(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayerSANM, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(in_size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.in_size = in_size
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ """Compute encoded features.
+
+ Args:
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ else:
+ x = stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute encoded features.
+
+ Args:
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.in_size == self.size:
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+ x = residual + attn
+ else:
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.feed_forward(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, cache
+
+
+@register_class("encoder_classes", "SANMVadEncoder")
+class SANMVadEncoder(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shfit: int = 0,
+ selfattention_layer_type: str = "sanm",
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ SinusoidalPositionEncoder(),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+
+ elif selfattention_layer_type == "sanm":
+ self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+ encoder_selfattn_layer_args0 = (
+ attention_heads,
+ input_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ self.encoders0 = repeat(
+ 1,
+ lambda lnum: EncoderLayerSANM(
+ input_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.encoders = repeat(
+ num_blocks-1,
+ lambda lnum: EncoderLayerSANM(
+ output_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+ no_future_masks = masks & sub_masks
+ xs_pad *= self.output_size()**0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+ f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ # xs_pad = self.dropout(xs_pad)
+ mask_tup0 = [masks, no_future_masks]
+ encoder_outs = self.encoders0(xs_pad, mask_tup0)
+ xs_pad, _ = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+
+
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ if layer_idx + 1 == len(self.encoders):
+ # This is last layer.
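+ # vad_mask (helper not shown in this file) presumably restricts attention in this
+ # final block to the VAD-delimited region given by vad_indexes, while the other
+ # blocks keep the causal no_future_masks.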
+ coner_mask = torch.ones(masks.size(0),
+ masks.size(-1),
+ masks.size(-1),
+ device=xs_pad.device,
+ dtype=torch.bool)
+ for word_index, length in enumerate(ilens):
+ coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+ vad_indexes[word_index],
+ device=xs_pad.device)
+ layer_mask = masks & coner_mask
+ else:
+ layer_mask = no_future_masks
+ mask_tup1 = [masks, layer_mask]
+ encoder_outs = encoder_layer(xs_pad, mask_tup1)
+ xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
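+
+ # Illustrative construction of this encoder (hyper-parameter values here are
+ # assumptions, not defaults taken from a released config):
+ #
+ # encoder = SANMVadEncoder(input_size=400, output_size=256, attention_heads=4,
+ # num_blocks=6, input_layer="pe", kernel_size=11)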
diff --git a/funasr/models/ct_transformer/target_delay_transformer.py b/funasr/models/ct_transformer/target_delay_transformer.py
index f3bc772..59884a3 100644
--- a/funasr/models/ct_transformer/target_delay_transformer.py
+++ b/funasr/models/ct_transformer/target_delay_transformer.py
@@ -6,7 +6,7 @@
import torch.nn as nn
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+from funasr.models.sanm.encoder import SANMEncoder as Encoder
class TargetDelayTransformer(torch.nn.Module):
diff --git a/funasr/models/fsmn_vad/vad_realtime_transformer.py b/funasr/models/ct_transformer/vad_realtime_transformer.py
similarity index 97%
rename from funasr/models/fsmn_vad/vad_realtime_transformer.py
rename to funasr/models/ct_transformer/vad_realtime_transformer.py
index e02f624..155057c 100644
--- a/funasr/models/fsmn_vad/vad_realtime_transformer.py
+++ b/funasr/models/ct_transformer/vad_realtime_transformer.py
@@ -6,7 +6,7 @@
import torch.nn as nn
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
+from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
class VadRealtimeTransformer(torch.nn.Module):
diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index 19c5612..c77cedf 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -10,13 +10,14 @@
from typing import Tuple
import torch
+import torch.nn as nn
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
+# from funasr.layers.abs_normalize import AbsNormalize
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.frontends.abs_frontend import AbsFrontend
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +29,16 @@
yield
-class Data2VecPretrainModel(FunASRModel):
+class Data2VecPretrainModel(nn.Module):
"""Data2Vec Pretrain model"""
def __init__(
self,
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- encoder: AbsEncoder,
- preencoder: Optional[AbsPreEncoder] = None,
+ frontend = None,
+ specaug = None,
+ normalize = None,
+ encoder = None,
+ preencoder = None,
):
super().__init__()
diff --git a/funasr/models/data2vec/data2vec_encoder.py b/funasr/models/data2vec/data2vec_encoder.py
index 1bcb639..4689e20 100644
--- a/funasr/models/data2vec/data2vec_encoder.py
+++ b/funasr/models/data2vec/data2vec_encoder.py
@@ -11,7 +11,6 @@
import torch.nn as nn
import torch.nn.functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.data2vec.data_utils import compute_mask_indices
from funasr.models.data2vec.ema_module import EMAModule
from funasr.models.data2vec.grad_multiply import GradMultiply
@@ -28,7 +27,7 @@
return end - r * pct_remaining
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(nn.Module):
def __init__(
self,
# for ConvFeatureExtractionModel
diff --git a/funasr/models/e_branchformer/e_branchformer_encoder.py b/funasr/models/e_branchformer/encoder.py
similarity index 97%
rename from funasr/models/e_branchformer/e_branchformer_encoder.py
rename to funasr/models/e_branchformer/encoder.py
index a7b4fde..5604c9f 100644
--- a/funasr/models/e_branchformer/e_branchformer_encoder.py
+++ b/funasr/models/e_branchformer/encoder.py
@@ -13,9 +13,8 @@
from typing import List, Optional, Tuple
import torch
-
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
+import torch.nn as nn
+from funasr.models.ctc.ctc import CTC
from funasr.models.branchformer.cgmlp import ConvolutionalGatingMLP
from funasr.models.branchformer.fastformer import FastSelfAttention
from funasr.models.transformer.utils.nets_utils import get_activation, make_pad_mask
@@ -34,8 +33,8 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import (
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling2,
Conv2dSubsampling6,
@@ -43,7 +42,7 @@
TooShortUttError,
check_short_utt,
)
-
+from funasr.utils.register import register_class
class EBranchformerEncoderLayer(torch.nn.Module):
"""E-Branchformer encoder layer module.
@@ -175,8 +174,8 @@
return x, mask
-
-class EBranchformerEncoder(AbsEncoder):
+@register_class("encoder_classes", "EBranchformerEncoder")
+class EBranchformerEncoder(nn.Module):
"""E-Branchformer encoder module."""
def __init__(
diff --git a/funasr/models/e_branchformer/model.py b/funasr/models/e_branchformer/model.py
index ca3e0f5..ccf1320 100644
--- a/funasr/models/e_branchformer/model.py
+++ b/funasr/models/e_branchformer/model.py
@@ -1,57 +1,9 @@
import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class
+@register_class("model_classes", "EBranchformer")
class EBranchformer(Transformer):
"""CTC-attention hybrid Encoder-Decoder model"""
diff --git a/funasr/models/eend/e2e_diar_eend_ola.py b/funasr/models/eend/e2e_diar_eend_ola.py
index 28aa223..ee80836 100644
--- a/funasr/models/eend/e2e_diar_eend_ola.py
+++ b/funasr/models/eend/e2e_diar_eend_ola.py
@@ -7,7 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.frontends.wav_frontend import WavFrontendMel23
from funasr.models.eend.encoder import EENDOLATransformerEncoder
from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.models.eend.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
diff --git a/funasr/models/eend/encoder.py b/funasr/models/eend/encoder.py
index 3065884..430f0d0 100644
--- a/funasr/models/eend/encoder.py
+++ b/funasr/models/eend/encoder.py
@@ -7,7 +7,7 @@
class MultiHeadSelfAttention(nn.Module):
def __init__(self, n_units, h=8, dropout_rate=0.1):
- super(MultiHeadSelfAttention, self).__init__()
+ super().__init__()
self.linearQ = nn.Linear(n_units, n_units)
self.linearK = nn.Linear(n_units, n_units)
self.linearV = nn.Linear(n_units, n_units)
diff --git a/funasr/models/frontend/__init__.py b/funasr/models/frontend/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/frontend/__init__.py
+++ /dev/null
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
deleted file mode 100644
index 6049a01..0000000
--- a/funasr/models/frontend/abs_frontend.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-
-class AbsFrontend(torch.nn.Module, ABC):
- @abstractmethod
- def output_size(self) -> int:
- raise NotImplementedError
-
- @abstractmethod
- def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index b930e0c..cc3c87e 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -27,7 +27,7 @@
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
- encoder_class = encoder_choices.get_class(encoder)
+ encoder_class = encoder_classes.get_class(encoder)
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py
index a426b51..fd8bc29 100644
--- a/funasr/models/language_model/rnn/decoders.py
+++ b/funasr/models/language_model/rnn/decoders.py
@@ -15,7 +15,7 @@
from funasr.metrics import end_detect
from funasr.models.transformer.utils.nets_utils import mask_by_length
from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import att_to_numpy
diff --git a/funasr/models/language_model/rnn/encoders.py b/funasr/models/language_model/rnn/encoders.py
index 819585b..d047a8e 100644
--- a/funasr/models/language_model/rnn/encoders.py
+++ b/funasr/models/language_model/rnn/encoders.py
@@ -7,7 +7,7 @@
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
-from funasr.metrics import get_vgg2l_odim
+from funasr.metrics.common import get_vgg2l_odim
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
diff --git a/funasr/models/transformer/transformer_encoder.py b/funasr/models/language_model/transformer_encoder.py
similarity index 66%
rename from funasr/models/transformer/transformer_encoder.py
rename to funasr/models/language_model/transformer_encoder.py
index 1126da0..21f3548 100644
--- a/funasr/models/transformer/transformer_encoder.py
+++ b/funasr/models/language_model/transformer_encoder.py
@@ -11,8 +11,6 @@
from torch import nn
import logging
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
@@ -22,18 +20,17 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.utils.nets_utils import rename_state_dict
+from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
class EncoderLayer(nn.Module):
@@ -143,216 +140,7 @@
return x, mask
-class TransformerEncoder(AbsEncoder):
- """Transformer encoder module.
-
- Args:
- input_size: input dim
- output_size: dimension of attention
- attention_heads: the number of heads of multi head attention
- linear_units: the number of units of position-wise feed forward
- num_blocks: the number of decoder blocks
- dropout_rate: dropout rate
- attention_dropout_rate: dropout rate in attention
- positional_dropout_rate: dropout rate after adding positional encoding
- input_layer: input layer type
- pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
- normalize_before: whether to use layer_norm before the first block
- concat_after: whether to concat attention layer's input and output
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied.
- i.e. x -> x + att(x)
- positionwise_layer_type: linear of conv1d
- positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
- padding_idx: padding_idx for input_layer=embed
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- ):
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
- self.encoders = repeat(
- num_blocks,
- lambda lnum: EncoderLayer(
- output_size,
- MultiHeadedAttention(
- attention_heads, output_size, attention_dropout_rate
- ),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- xs_pad, masks = self.encoders(xs_pad, masks)
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- xs_pad, masks = encoder_layer(xs_pad, masks)
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
-
-def _pre_hook(
- state_dict,
- prefix,
- local_metadata,
- strict,
- missing_keys,
- unexpected_keys,
- error_msgs,
-):
- # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
- rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
- # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
- rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
-
-
-class TransformerEncoder_s0(torch.nn.Module):
+class TransformerEncoder_lm(nn.Module):
"""Transformer encoder module.
Args:
@@ -418,8 +206,7 @@
conditioning_layer_dim=None,
):
"""Construct an Encoder object."""
- super(TransformerEncoder_s0, self).__init__()
- self._register_load_state_dict_pre_hook(_pre_hook)
+ super().__init__()
self.conv_subsampling_factor = 1
if input_layer == "linear":
diff --git a/funasr/models/mfcca/e2e_asr_mfcca.py b/funasr/models/mfcca/e2e_asr_mfcca.py
index 5ec9e94..48534dd 100644
--- a/funasr/models/mfcca/e2e_asr_mfcca.py
+++ b/funasr/models/mfcca/e2e_asr_mfcca.py
@@ -9,15 +9,15 @@
import torch
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/mfcca/mfcca_encoder.py b/funasr/models/mfcca/mfcca_encoder.py
index 92dd6e7..52da33d 100644
--- a/funasr/models/mfcca/mfcca_encoder.py
+++ b/funasr/models/mfcca/mfcca_encoder.py
@@ -26,13 +26,13 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
diff --git a/funasr/models/model_class_factory.py b/funasr/models/model_class_factory.py
deleted file mode 100644
index 23d2304..0000000
--- a/funasr/models/model_class_factory.py
+++ /dev/null
@@ -1,162 +0,0 @@
-
-from funasr.models.normalize.global_mvn import GlobalMVN
-from funasr.models.normalize.utterance_mvn import UtteranceMVN
-from funasr.models.ctc.ctc import CTC
-
-from funasr.models.transducer.rnn_decoder import RNNDecoder
-from funasr.models.sanm.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
-from funasr.models.transformer.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.transformer.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import ParaformerDecoderSAN
-from funasr.models.transformer.transformer_decoder import TransformerDecoder
-from funasr.models.paraformer.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.transformer.transformer_decoder import SAAsrTransformerDecoder
-
-from funasr.models.transducer.rnnt_decoder import RNNTDecoder
-from funasr.models.transducer.joint_network import JointNetwork
-
-
-from funasr.models.conformer.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
-from funasr.models.data2vec.data2vec_encoder import Data2VecEncoder
-from funasr.models.transducer.rnn_encoder import RNNEncoder
-from funasr.models.sanm.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.transformer.transformer_encoder import TransformerEncoder
-from funasr.models.branchformer.branchformer_encoder import BranchformerEncoder
-from funasr.models.e_branchformer.e_branchformer_encoder import EBranchformerEncoder
-from funasr.models.mfcca.mfcca_encoder import MFCCAEncoder
-from funasr.models.sond.encoder.resnet34_encoder import ResNet34Diar
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.default import MultiChannelFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-
-
-from funasr.models.paraformer.cif_predictor import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.transformer.subsampling import Conv1dSubsampling
-from funasr.utils.class_choices import ClassChoices
-from funasr.models.fsmn_vad.fsmn_encoder import FSMN
-
-from funasr.models.sond.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.sond.encoder.conv_encoder import ConvEncoder
-from funasr.models.sond.encoder.fsmn_encoder import FsmnEncoder
-from funasr.models.sond.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-
-from funasr.models.sond.encoder.conv_encoder import ConvEncoder
-from funasr.models.sond.encoder.fsmn_encoder import FsmnEncoder
-from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
-from funasr.models.eend.encoder import EENDOLATransformerEncoder
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- multichannelfrontend=MultiChannelFrontend,
- ),
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(
- specaug=SpecAug,
- specaug_lfr=SpecAugLFR,
- ),
- default=None,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- default=None,
-)
-
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- mfcca_enc=MFCCAEncoder,
- chunk_conformer=ConformerChunkEncoder,
- fsmn=FSMN,
- branchformer=BranchformerEncoder,
- e_branchformer=EBranchformerEncoder,
- resnet34=ResNet34Diar,
- resnet34_sp_l2reg=ResNet34SpL2RegDiar,
- ecapa_tdnn=ECAPA_TDNN,
- eend_ola_transformer=EENDOLATransformerEncoder,
- conv=ConvEncoder,
- resnet34_diar=ResNet34Diar,
- ),
- default="rnn",
-)
-
-
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- rnn=RNNDecoder,
- fsmn_scama_opt=FsmnDecoderSCAMAOpt,
- paraformer_decoder_sanm=ParaformerSANMDecoder,
- paraformer_decoder_san=ParaformerDecoderSAN,
- contextual_paraformer_decoder=ContextualParaformerDecoder,
- sa_decoder=SAAsrTransformerDecoder,
- rnnt=RNNTDecoder,
- ),
- default="transformer",
-)
-
-
-joint_network_choices = ClassChoices(
- name="joint_network",
- classes=dict(
- joint_network=JointNetwork,
- ),
- default="joint_network",
-)
-
-predictor_choices = ClassChoices(
- name="predictor",
- classes=dict(
- cif_predictor=CifPredictor,
- ctc_predictor=None,
- cif_predictor_v2=CifPredictorV2,
- cif_predictor_v3=CifPredictorV3,
- bat_predictor=BATPredictor,
- ),
- default="cif_predictor",
-)
-
-stride_conv_choices = ClassChoices(
- name="stride_conv",
- classes=dict(
- stride_conv1d=Conv1dSubsampling
- ),
- default="stride_conv1d",
-)
\ No newline at end of file
diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/neat_contextual_paraformer/decoder.py
similarity index 98%
rename from funasr/models/paraformer/contextual_decoder.py
rename to funasr/models/neat_contextual_paraformer/decoder.py
index 626cdef..ca689d3 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/neat_contextual_paraformer/decoder.py
@@ -6,15 +6,15 @@
import numpy as np
from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.utils.register import register_class, registry_tables
class ContextualDecoderLayer(nn.Module):
def __init__(
@@ -98,7 +98,7 @@
x = self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
-
+@register_class("decoder_classes", "ContextualParaformerDecoder")
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
diff --git a/funasr/models/neat_contextual_paraformer/model.py b/funasr/models/neat_contextual_paraformer/model.py
new file mode 100644
index 0000000..7891307
--- /dev/null
+++ b/funasr/models/neat_contextual_paraformer/model.py
@@ -0,0 +1,533 @@
+import os
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.frontends.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.paraformer.cif_predictor import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.paraformer.cif_predictor import CifPredictorV3
+from funasr.models.paraformer.search import Hypothesis
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+
+
+from funasr.models.paraformer.model import Paraformer
+
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "NeatContextualParaformer")
+class NeatContextualParaformer(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Contextual (hotword-biased) variant of Paraformer: a bias encoder over hotword
+ embeddings injects contextual information into the decoder.
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.target_buffer_length = kwargs.get("target_buffer_length", -1)
+ inner_dim = kwargs.get("inner_dim", 256)
+ bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
+ use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
+ crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
+ crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
+ bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+
+
+ if bias_encoder_type == 'lstm':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
+ self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+ elif bias_encoder_type == 'mean':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+ else:
+ logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
+
+ if self.target_buffer_length > 0:
+ self.hotword_buffer = None
+ self.length_record = []
+ self.current_buffer_length = 0
+ self.use_decoder_embedding = use_decoder_embedding
+ self.crit_attn_weight = crit_attn_weight
+ if self.crit_attn_weight > 0:
+ self.attn_loss = torch.nn.L1Loss()
+ self.crit_attn_smooth = crit_attn_smooth
+
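+ # Illustrative keyword arguments consumed by this __init__ (values are assumptions):
+ # NeatContextualParaformer(..., inner_dim=512, bias_encoder_type="lstm",
+ # use_decoder_embedding=False, crit_attn_weight=0.0)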
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ hotword_pad = kwargs.get("hotword_pad")
+ hotword_lengths = kwargs.get("hotword_lengths")
+ dha_pad = kwargs.get("dha_pad")
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+ loss_ctc, cer_ctc = None, None
+
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+
+ # 2b. Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
+ encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
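+ # e.g. with (assumed) ctc_weight = 0.3 and predictor_weight = 1.0:
+ # loss = 0.3 * loss_ctc + 0.7 * loss_att + 1.0 * loss_pre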
+
+ if loss_ideal is not None:
+ loss = loss + loss_ideal * self.crit_attn_weight
+ stats["loss_ideal"] = loss_ideal.detach().cpu()
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + self.predictor_bias).sum())
+
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+
+ def _calc_att_clas_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ hotword_pad: torch.Tensor,
+ hotword_lengths: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # -1. bias encoder
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hotword_pad)
+ else:
+ hw_embed = self.bias_embed(hotword_pad)
+ hw_embed, (_, _) = self.bias_encoder(hw_embed)
+ _ind = np.arange(0, hotword_pad.shape[0]).tolist()
+ selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
+ contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
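+ # selected picks the LSTM output at the last valid step of each hotword as its
+ # fixed-size embedding; contextual_info broadcasts the hotword set to every
+ # utterance in the batch.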
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds, contextual_info)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ '''
+ if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
+ ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
+ attn_non_blank = attn[:,:,:,:-1]
+ ideal_attn_non_blank = ideal_attn[:,:,:-1]
+ loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
+ else:
+ loss_ideal = None
+ '''
+ loss_ideal = None
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0,
+ index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
+ value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
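+ # Worked example (numbers are assumptions): with sampling_ratio = 0.75, an
+ # utterance with 20 valid tokens of which the first-pass decoder predicts 12
+ # correctly gets target_num = int((20 - 12) * 0.75) = 6 random positions of
+ # pre_acoustic_embeds replaced by ground-truth embeddings before the second pass.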
+
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+ clas_scale=1.0):
+ if hw_list is None:
+ hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
+ hw_list_pad = pad_list(hw_list, 0)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
+ hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+ else:
+ hw_lengths = [len(i) for i in hw_list]
+ hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+ enforce_sorted=False)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # hotword
+ self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+
+
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
+ pre_acoustic_embeds,
+ pre_token_length,
+ hw_list=self.hotword_list,
+ clas_scale=kwargs.get("clas_scale", 1.0))
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # wrap the greedy output with sos/eos so it matches the beam-search hypothesis format
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(
+ filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+ else:
+ result_i = {"key": key[i], "token_int": token_int}
+ results.append(result_i)
+
+ return results, meta_data
+
+
+ def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
+ def load_seg_dict(seg_dict_file):
+ seg_dict = {}
+ assert isinstance(seg_dict_file, str)
+ with open(seg_dict_file, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+ def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+ seg_dict = None
+ if frontend.cmvn_file is not None:
+ model_dir = os.path.dirname(frontend.cmvn_file)
+ seg_dict_file = os.path.join(model_dir, 'seg_dict')
+ if os.path.exists(seg_dict_file):
+ seg_dict = load_seg_dict(seg_dict_file)
+ else:
+ seg_dict = None
+ # for None
+ if hotword_list_or_file is None:
+ hotword_list = None
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords from local txt...")
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_str_list.append(hw)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+ elif hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords from url...")
+ work_dir = tempfile.mkdtemp() # mkdtemp creates the directory and keeps it until we are done
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+ local_file = requests.get(hotword_list_or_file)
+ with open(text_file_path, "wb") as fout:
+ fout.write(local_file.content)
+ hotword_list_or_file = text_file_path
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_str_list.append(hw)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for text str input
+ elif not hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords as str...")
+ hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ hw_list = hw.strip().split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_list.append(tokenizer.tokens2ids(hw_list))
+ hotword_list.append([self.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ else:
+ hotword_list = None
+ return hotword_list
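+ # Illustrative inputs this helper accepts (the example values are assumptions):
+ # generate_hotwords_list(None) -> None
+ # generate_hotwords_list("hotwords.txt") -> token ids per line, plus a trailing [sos] entry
+ # generate_hotwords_list("https://host/hotwords.txt") -> same, after downloading the file
+ # generate_hotwords_list("alibaba damo") -> token ids per space-separated word, plus [sos]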
+
diff --git a/funasr/models/normalize/global_mvn.py b/funasr/models/normalize/global_mvn.py
index a9b7935..eea84dc 100644
--- a/funasr/models/normalize/global_mvn.py
+++ b/funasr/models/normalize/global_mvn.py
@@ -6,8 +6,9 @@
import torch
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.register import register_class, registry_tables
-
+@register_class("normalize_classes", "GlobalMVN")
class GlobalMVN(torch.nn.Module):
"""Apply global mean and variance normalization
TODO(kamo): Make this class portable somehow
diff --git a/funasr/models/normalize/utterance_mvn.py b/funasr/models/normalize/utterance_mvn.py
index 609c7aa..60703fb 100644
--- a/funasr/models/normalize/utterance_mvn.py
+++ b/funasr/models/normalize/utterance_mvn.py
@@ -3,8 +3,9 @@
import torch
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.register import register_class, registry_tables
-
+@register_class("normalize_classes", "UtteranceMVN")
class UtteranceMVN(torch.nn.Module):
def __init__(
self,
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 7cc088f..c1b7d7a 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -8,9 +8,12 @@
from funasr.models.scama.utils import sequence_mask
from typing import Optional, Tuple
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("predictor_classes", "CifPredictor")
class CifPredictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
- super(CifPredictor, self).__init__()
+ super().__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
@@ -133,7 +136,7 @@
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
-
+@register_class("predictor_classes", "CifPredictorV2")
class CifPredictorV2(nn.Module):
def __init__(self,
idim,
@@ -505,372 +508,3 @@
fires = torch.stack(list_fires, 1)
return fires
-
-class CifPredictorV3(nn.Module):
- def __init__(self,
- idim,
- l_order,
- r_order,
- threshold=1.0,
- dropout=0.1,
- smooth_factor=1.0,
- noise_threshold=0,
- tail_threshold=0.0,
- tf2torch_tensor_name_prefix_torch="predictor",
- tf2torch_tensor_name_prefix_tf="seq2seq/cif",
- smooth_factor2=1.0,
- noise_threshold2=0,
- upsample_times=5,
- upsample_type="cnn",
- use_cif1_cnn=True,
- tail_mask=True,
- ):
- super(CifPredictorV3, self).__init__()
-
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
- self.cif_output = nn.Linear(idim, 1)
- self.dropout = torch.nn.Dropout(p=dropout)
- self.threshold = threshold
- self.smooth_factor = smooth_factor
- self.noise_threshold = noise_threshold
- self.tail_threshold = tail_threshold
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
- self.upsample_times = upsample_times
- self.upsample_type = upsample_type
- self.use_cif1_cnn = use_cif1_cnn
- if self.upsample_type == 'cnn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.cif_output2 = nn.Linear(idim, 1)
- elif self.upsample_type == 'cnn_blstm':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
- self.cif_output2 = nn.Linear(idim*2, 1)
- elif self.upsample_type == 'cnn_attn':
- self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
- from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
- from funasr.models.transformer.attention import MultiHeadedAttention
- from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- positionwise_layer_args = (
- idim,
- idim*2,
- 0.1,
- )
- self.self_attn = TransformerEncoderLayer(
- idim,
- MultiHeadedAttention(
- 4, idim, 0.1
- ),
- PositionwiseFeedForward(*positionwise_layer_args),
- 0.1,
- True, #normalize_before,
- False, #concat_after,
- )
- self.cif_output2 = nn.Linear(idim, 1)
- self.smooth_factor2 = smooth_factor2
- self.noise_threshold2 = noise_threshold2
-
- def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
- target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
-
- # alphas2 is an extra head for timestamp prediction
- if not self.use_cif1_cnn:
- _output = context
- else:
- _output = output
- if self.upsample_type == 'cnn':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- elif self.upsample_type == 'cnn_blstm':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- output2, (_, _) = self.blstm(output2)
- elif self.upsample_type == 'cnn_attn':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- output2, _ = self.self_attn(output2, mask)
- # import pdb; pdb.set_trace()
- alphas2 = torch.sigmoid(self.cif_output2(output2))
- alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
- # repeat the mask in T demension to match the upsampled length
- if mask is not None:
- mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
- mask2 = mask2.unsqueeze(-1)
- alphas2 = alphas2 * mask2
- alphas2 = alphas2.squeeze(-1)
- token_num2 = alphas2.sum(-1)
-
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
-
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
-
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
- return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
- def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
- h = hidden
- b = hidden.shape[0]
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
-
- # alphas2 is an extra head for timestamp prediction
- if not self.use_cif1_cnn:
- _output = context
- else:
- _output = output
- if self.upsample_type == 'cnn':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- elif self.upsample_type == 'cnn_blstm':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- output2, (_, _) = self.blstm(output2)
- elif self.upsample_type == 'cnn_attn':
- output2 = self.upsample_cnn(_output)
- output2 = output2.transpose(1,2)
- output2, _ = self.self_attn(output2, mask)
- alphas2 = torch.sigmoid(self.cif_output2(output2))
- alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
- # repeat the mask in T demension to match the upsampled length
- if mask is not None:
- mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
- mask2 = mask2.unsqueeze(-1)
- alphas2 = alphas2 * mask2
- alphas2 = alphas2.squeeze(-1)
- _token_num = alphas2.sum(-1)
- if token_num is not None:
- alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
- # re-downsample
- ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
- ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
- # upsampled alphas and cif_peak
- us_alphas = alphas2
- us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
- return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
-
- def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
- b, t, d = hidden.size()
- tail_threshold = self.tail_threshold
- if mask is not None:
- zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
- ones_t = torch.ones_like(zeros_t)
- mask_1 = torch.cat([mask, zeros_t], dim=1)
- mask_2 = torch.cat([ones_t, mask], dim=1)
- mask = mask_2 - mask_1
- tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, zeros_t], dim=1)
- alphas = torch.add(alphas, tail_threshold)
- else:
- tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
- tail_threshold = torch.reshape(tail_threshold, (1, 1))
- alphas = torch.cat([alphas, tail_threshold], dim=1)
- zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
- hidden = torch.cat([hidden, zeros], dim=1)
- token_num = alphas.sum(dim=-1)
- token_num_floor = torch.floor(token_num)
-
- return hidden, alphas, token_num_floor
-
- def gen_frame_alignments(self,
- alphas: torch.Tensor = None,
- encoder_sequence_length: torch.Tensor = None):
- batch_size, maximum_length = alphas.size()
- int_type = torch.int32
-
- is_training = self.training
- if is_training:
- token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
- else:
- token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
-
- max_token_num = torch.max(token_num).item()
-
- alphas_cumsum = torch.cumsum(alphas, dim=1)
- alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
- alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
-
- index = torch.ones([batch_size, max_token_num], dtype=int_type)
- index = torch.cumsum(index, dim=1)
- index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
-
- index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
- index_div_bool_zeros = index_div.eq(0)
- index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
- index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
- token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
- index_div_bool_zeros_count *= token_num_mask
-
- index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
- ones = torch.ones_like(index_div_bool_zeros_count_tile)
- zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
- ones = torch.cumsum(ones, dim=2)
- cond = index_div_bool_zeros_count_tile == ones
- index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
-
- index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
- index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
- index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
- index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
- predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
- int_type).to(encoder_sequence_length.device)
- index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
-
- predictor_alignments = index_div_bool_zeros_count_tile_out
- predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
- return predictor_alignments.detach(), predictor_alignments_length.detach()
-
-class BATPredictor(nn.Module):
- def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
- super(BATPredictor, self).__init__()
-
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
- self.cif_output = nn.Linear(idim, 1)
- self.dropout = torch.nn.Dropout(p=dropout)
- self.threshold = threshold
- self.smooth_factor = smooth_factor
- self.noise_threshold = noise_threshold
- self.return_accum = return_accum
-
- def cif(
- self,
- input: Tensor,
- alpha: Tensor,
- beta: float = 1.0,
- return_accum: bool = False,
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- B, S, C = input.size()
- assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
-
- dtype = alpha.dtype
- alpha = alpha.float()
-
- alpha_sum = alpha.sum(1)
- feat_lengths = (alpha_sum / beta).floor().long()
- T = feat_lengths.max()
-
- # aggregate and integrate
- csum = alpha.cumsum(-1)
- with torch.no_grad():
- # indices used for scattering
- right_idx = (csum / beta).floor().long().clip(max=T)
- left_idx = right_idx.roll(1, dims=1)
- left_idx[:, 0] = 0
-
- # count # of fires from each source
- fire_num = right_idx - left_idx
- extra_weights = (fire_num - 1).clip(min=0)
- # The extra entry in last dim is for
- output = input.new_zeros((B, T + 1, C))
- source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
- zero = alpha.new_zeros((1,))
-
- # right scatter
- fire_mask = fire_num > 0
- right_weight = torch.where(
- fire_mask,
- csum - right_idx.type_as(alpha) * beta,
- zero
- ).type_as(input)
- # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
- output.scatter_add_(
- 1,
- right_idx.unsqueeze(-1).expand(-1, -1, C),
- right_weight.unsqueeze(-1) * input
- )
-
- # left scatter
- left_weight = (
- alpha - right_weight - extra_weights.type_as(alpha) * beta
- ).type_as(input)
- output.scatter_add_(
- 1,
- left_idx.unsqueeze(-1).expand(-1, -1, C),
- left_weight.unsqueeze(-1) * input
- )
-
- # extra scatters
- if extra_weights.ge(0).any():
- extra_steps = extra_weights.max().item()
- tgt_idx = left_idx
- src_feats = input * beta
- for _ in range(extra_steps):
- tgt_idx = (tgt_idx + 1).clip(max=T)
- # (B, S, 1)
- src_mask = (extra_weights > 0)
- output.scatter_add_(
- 1,
- tgt_idx.unsqueeze(-1).expand(-1, -1, C),
- src_feats * src_mask.unsqueeze(2)
- )
- extra_weights -= 1
-
- output = output[:, :T, :]
-
- if return_accum:
- return output, csum
- else:
- return output, alpha
-
- def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- memory = self.cif_conv1d(queries)
- output = memory + context
- output = self.dropout(output)
- output = output.transpose(1, 2)
- output = torch.relu(output)
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
- if mask is not None:
- alphas = alphas * mask.transpose(-1, -2).float()
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- # logging.info("target_length: {}".format(target_length))
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
- # target_length = length_noise + target_length
- alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
- acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
- return acoustic_embeds, token_num, alphas, cif_peak
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
new file mode 100644
index 0000000..3fe9d19
--- /dev/null
+++ b/funasr/models/paraformer/decoder.py
@@ -0,0 +1,625 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.decoder import DecoderLayer
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.utils.register import register_class, registry_tables
+
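+# Note: despite the "self_attn" attribute name, DecoderLayerSANM uses an FSMN-style
+# memory block (MultiHeadedAttentionSANMDecoder) rather than causal self-attention,
+# which is consistent with Paraformer's non-autoregressive (parallel) decoding.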
+class DecoderLayerSANM(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (encoder-decoder attention) module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(DecoderLayerSANM, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ if self_attn is not None:
+ self.norm2 = LayerNorm(size)
+ if src_attn is not None:
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, _ = self.self_attn(tgt, tgt_mask)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ if self.training:
+ cache = None
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            fsmn_cache (torch.Tensor): Cached state of the FSMN self-attention block.
+            opt_cache (torch.Tensor): Cached state of the cross-attention block.
+            chunk_size (int): Streaming decoding chunk size.
+            look_back (int): Number of chunks of cross-attention cache to look back.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Updated FSMN cache.
+            torch.Tensor: Updated cross-attention cache.
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+ x = residual + x
+
+ return x, memory, fsmn_cache, opt_cache
+
+
+@register_class("decoder_classes", "ParaformerSANMDecoder")
+class ParaformerSANMDecoder(BaseTransformerDecoder):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+ """
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
+ chunk_multiply_factor: tuple = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ # pos_enc_class(attention_dim, positional_dropout_rate),
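+                # positional encoding is commented out here, presumably because the decoder is normally fed CIF acoustic embeddings rather than token embeddings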
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+ self.att_layer_num = att_layer_num
+ self.num_blocks = num_blocks
+ if sanm_shfit is None:
+ sanm_shfit = (kernel_size - 1) // 2
+ self.decoders = repeat(
+ att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+ ),
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if num_blocks - att_layer_num <= 0:
+ self.decoders2 = None
+ else:
+ self.decoders2 = repeat(
+ num_blocks - att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+ ),
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.decoders3 = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.chunk_multiply_factor = chunk_multiply_factor
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ if chunk_mask is not None:
+ memory_mask = memory_mask * chunk_mask
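+        # if the target and memory masks disagree along dim 1 (presumably an off-by-one introduced by the chunk mask), repeat an existing mask row to align them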
+ if tgt_mask.size(1) != memory_mask.size(1):
+ memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+ )
+ return logp.squeeze(0), state
+
+ def forward_chunk(
+ self,
+ memory: torch.Tensor,
+ tgt: torch.Tensor,
+ cache: dict = None,
+    ) -> torch.Tensor:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ x = tgt
+ if cache["decode_fsmn"] is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ fsmn_cache = [None] * cache_layer_num
+ else:
+ fsmn_cache = cache["decode_fsmn"]
+
+ if cache["opt"] is None:
+ cache_layer_num = len(self.decoders)
+ opt_cache = [None] * cache_layer_num
+ else:
+ opt_cache = cache["opt"]
+
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
+ x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
+ chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+ )
+
+ if self.num_blocks - self.att_layer_num > 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
+ x, memory, fsmn_cache=fsmn_cache[j]
+ )
+
+ for decoder in self.decoders3:
+ x, memory, _, _ = decoder.forward_chunk(
+ x, memory
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ cache["decode_fsmn"] = fsmn_cache
+ if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
+ cache["opt"] = opt_cache
+ return x
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+ """
+ x = self.embed(tgt)
+ if cache is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ cache = [None] * cache_layer_num
+ new_cache = []
+ # for c, decoder in zip(cache, self.decoders):
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(c_ret)
+
+ if self.num_blocks - self.att_layer_num > 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ c = cache[j]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(c_ret)
+
+ for decoder in self.decoders3:
+
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=None
+ )
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+ return y, new_cache
+
+
+@register_class("decoder_classes", "ParaformerDecoderSAN")
+class ParaformerDecoderSAN(BaseTransformerDecoder):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+ """
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ embeds_id: int = -1,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
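+        # embeds_id selects which decoder layer's hidden states are additionally returned by forward(); the default of -1 disables this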
+ self.embeds_id = embeds_id
+ self.attention_dim = attention_dim
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+
+ memory = hs_pad
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+ memory.device
+ )
+ # Padding for Longformer
+ if memory_mask.shape[-1] != memory.shape[1]:
+ padlen = memory.shape[1] - memory_mask.shape[-1]
+ memory_mask = torch.nn.functional.pad(
+ memory_mask, (0, padlen), "constant", False
+ )
+
+ # x = self.embed(tgt)
+ x = tgt
+ embeds_outputs = None
+ for layer_id, decoder in enumerate(self.decoders):
+ x, tgt_mask, memory, memory_mask = decoder(
+ x, tgt_mask, memory, memory_mask
+ )
+ if layer_id == self.embeds_id:
+ embeds_outputs = x
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ if embeds_outputs is not None:
+ return x, olens, embeds_outputs
+ else:
+ return x, olens
+
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 46bd7b0..fad8385 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -1,56 +1,34 @@
+import os
import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+from typing import Union, Dict, List, Tuple, Optional
+
import torch
import torch.nn as nn
-import random
-import numpy as np
+
import time
-# from funasr.layers.abs_normalize import AbsNormalize
+
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
+
from funasr.models.paraformer.search import Hypothesis
-from funasr.models.model_class_factory import *
+from torch.cuda.amp import autocast
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import register_class, registry_tables
+from funasr.models.ctc.ctc import CTC
+@register_class("model_classes", "Paraformer")
class Paraformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -101,24 +79,19 @@
):
super().__init__()
-
- # import pdb;
- # pdb.set_trace()
-
- if frontend is not None:
- frontend_class = frontend_choices.get_class(frontend)
- frontend = frontend_class(**frontend_conf)
+
if specaug is not None:
- specaug_class = specaug_choices.get_class(specaug)
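+            # components are now resolved by name from registry_tables, which is populated via the @register_class decorator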
+ specaug_class = registry_tables.specaug_classes.get(specaug.lower())
specaug = specaug_class(**specaug_conf)
if normalize is not None:
- normalize_class = normalize_choices.get_class(normalize)
+ normalize_class = registry_tables.normalize_classes.get(normalize.lower())
normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_choices.get_class(encoder)
+ encoder_class = registry_tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
+
if decoder is not None:
- decoder_class = decoder_choices.get_class(decoder)
+ decoder_class = registry_tables.decoder_classes.get(decoder.lower())
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -133,7 +106,7 @@
odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
)
if predictor is not None:
- predictor_class = predictor_choices.get_class(predictor)
+ predictor_class = registry_tables.predictor_classes.get(predictor.lower())
predictor = predictor_class(**predictor_conf)
# note that eos is the same as sos (equivalent ID)
@@ -145,7 +118,7 @@
self.ctc_weight = ctc_weight
# self.token_list = token_list.copy()
#
- self.frontend = frontend
+ # self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
# self.preencoder = preencoder
@@ -275,7 +248,7 @@
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
+ """Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -469,10 +442,11 @@
self.beam_search = beam_search
def generate(self,
- data_in: list,
- data_lengths: list=None,
+ data_in,
+ data_lengths=None,
key: list=None,
tokenizer=None,
+ frontend=None,
**kwargs,
):
@@ -485,16 +459,23 @@
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-
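+        # data_in may be a precomputed fbank tensor or raw audio inputs (e.g. file paths or waveforms); raw audio is converted with the frontend passed in by the caller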
+ if isinstance(data_in, torch.Tensor): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
# Encoder
@@ -550,1211 +531,22 @@
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-
- return results, meta_data
-
-
-
-class BiCifParaformer(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
-
-
- def _calc_pre2_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
-
- # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
- loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
-
- return loss_pre2
-
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
- encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # 0. sampler
- decoder_out_1st = None
- if self.sampling_ratio > 0.0:
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre
-
-
- def calc_predictor(self, encoder_out, encoder_out_lens):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
- None,
- encoder_out_mask,
- ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
-
- def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
- return ds_alphas, ds_cif_peak, us_alphas, us_peaks
-
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # decoder: CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
-
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- loss_pre2 = self._calc_pre2_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
- else:
- loss = self.ctc_weight * loss_ctc + (
- 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
- stats["loss_pre2"] = loss_pre2.detach().cpu()
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + self.predictor_bias).sum())
-
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def generate(self,
- data_in: list,
- data_lengths: list = None,
- key: list = None,
- tokenizer=None,
- **kwargs,
- ):
-
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
- frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data[
- "batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
- pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- # BiCifParaformer, test no bias cif2
-
- _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
- pre_token_length)
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
- us_peaks[i][:encoder_out_lens[i] * 3],
- copy.copy(token),
- vad_offset=kwargs.get("begin_time", 0))
-
- text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
-
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
- "time_stamp_postprocessed": time_stamp_postprocessed,
- "word_lists": word_lists
- }
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
-
- return results, meta_data
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
-
-class NeatContextualParaformer(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
-
- self.target_buffer_length = kwargs.get("target_buffer_length", -1)
- inner_dim = kwargs.get("inner_dim", 256)
- bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
- use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
- crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
- crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
- bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
-
-
- if bias_encoder_type == 'lstm':
- logging.warning("enable bias encoder sampling and contextual training")
- self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
- self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
- elif bias_encoder_type == 'mean':
- logging.warning("enable bias encoder sampling and contextual training")
- self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
- else:
- logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
-
- if self.target_buffer_length > 0:
- self.hotword_buffer = None
- self.length_record = []
- self.current_buffer_length = 0
- self.use_decoder_embedding = use_decoder_embedding
- self.crit_attn_weight = crit_attn_weight
- if self.crit_attn_weight > 0:
- self.attn_loss = torch.nn.L1Loss()
- self.crit_attn_smooth = crit_attn_smooth
-
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- hotword_pad: torch.Tensor,
- hotword_lengths: torch.Tensor,
- dha_pad: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
-
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-
- loss_ctc, cer_ctc = None, None
-
- stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
-
- # 2b. Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
- encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- if loss_ideal is not None:
- loss = loss + loss_ideal * self.crit_attn_weight
- stats["loss_ideal"] = loss_ideal.detach().cpu()
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + self.predictor_bias).sum())
-
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
-
- def _calc_att_clas_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- hotword_pad: torch.Tensor,
- hotword_lengths: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
-
- # -1. bias encoder
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hotword_pad)
- else:
- hw_embed = self.bias_embed(hotword_pad)
- hw_embed, (_, _) = self.bias_encoder(hw_embed)
- _ind = np.arange(0, hotword_pad.shape[0]).tolist()
- selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
- contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-
- # 0. sampler
- decoder_out_1st = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds, contextual_info)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- '''
- if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
- ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
- attn_non_blank = attn[:,:,:,:-1]
- ideal_attn_non_blank = ideal_attn[:,:,:-1]
- loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
- else:
- loss_ideal = None
- '''
- loss_ideal = None
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
-
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0,
- index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
- value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
- clas_scale=1.0):
- if hw_list is None:
- hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
- hw_list_pad = pad_list(hw_list, 0)
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hw_list_pad)
- else:
- hw_embed = self.bias_embed(hw_list_pad)
- hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
- hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
- else:
- hw_lengths = [len(i) for i in hw_list]
- hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
- if self.use_decoder_embedding:
- hw_embed = self.decoder.embed(hw_list_pad)
- else:
- hw_embed = self.bias_embed(hw_list_pad)
- hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
- enforce_sorted=False)
- _, (h_n, _) = self.bias_encoder(hw_embed)
- hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def generate(self,
- data_in: list,
- data_lengths: list = None,
- key: list = None,
- tokenizer=None,
- **kwargs,
- ):
-
- # init beamsearch
- is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
-
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
- frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data[
- "batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
- # hotword
- self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer)
-
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
-
-
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
- pre_acoustic_embeds,
- pre_token_length,
- hw_list=self.hotword_list,
- clas_scale=kwargs.get("clas_scale", 1.0))
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(
- filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ result_i = {"key": key[i], "token_int": token_int}
results.append(result_i)
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-
return results, meta_data
-
-
- def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None):
- def load_seg_dict(seg_dict_file):
- seg_dict = {}
- assert isinstance(seg_dict_file, str)
- with open(seg_dict_file, "r", encoding="utf8") as f:
- lines = f.readlines()
- for line in lines:
- s = line.strip().split()
- key = s[0]
- value = s[1:]
- seg_dict[key] = " ".join(value)
- return seg_dict
-
- def seg_tokenize(txt, seg_dict):
- pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
- out_txt = ""
- for word in txt:
- word = word.lower()
- if word in seg_dict:
- out_txt += seg_dict[word] + " "
- else:
- if pattern.match(word):
- for char in word:
- if char in seg_dict:
- out_txt += seg_dict[char] + " "
- else:
- out_txt += "<unk>" + " "
- else:
- out_txt += "<unk>" + " "
- return out_txt.strip().split()
-
- seg_dict = None
- if self.frontend.cmvn_file is not None:
- model_dir = os.path.dirname(self.frontend.cmvn_file)
- seg_dict_file = os.path.join(model_dir, 'seg_dict')
- if os.path.exists(seg_dict_file):
- seg_dict = load_seg_dict(seg_dict_file)
- else:
- seg_dict = None
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hw_list = hw.split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_str_list.append(hw)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hw_list = hw.strip().split()
- if seg_dict is not None:
- hw_list = seg_tokenize(hw_list, seg_dict)
- hotword_list.append(tokenizer.tokens2ids(hw_list))
- hotword_list.append([self.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
-
-
-class ParaformerOnline(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
-
- def __init__(
- self,
- *args,
- **kwargs,
- ):
-
- super().__init__(*args, **kwargs)
-
- # import pdb;
- # pdb.set_trace()
- self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
-
-
- self.scama_mask = None
- if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
- from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
- self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
- self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
-
-
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- # import pdb;
- # pdb.set_trace()
- decoding_ind = kwargs.get("decoding_ind")
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # Encoder
- if hasattr(self.encoder, "overlap_chunk_cls"):
- ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_ctc, cer_ctc = None, None
- loss_pre = None
- stats = dict()
-
- # decoder: CTC branch
-
- if self.ctc_weight > 0.0:
- if hasattr(self.encoder, "overlap_chunk_cls"):
- encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- else:
- encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
-
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
- )
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight
- else:
- loss = self.ctc_weight * loss_ctc + (
- 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
- stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum()
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode_chunk(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
-
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
-
- # Forward encoder
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- return encoder_out, torch.tensor([encoder_out.size(1)])
-
- def _calc_att_predictor_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- if self.predictor_bias == 1:
- _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_pad_lens = ys_pad_lens + self.predictor_bias
- mask_chunk_predictor = None
- if self.encoder.overlap_chunk_cls is not None:
- mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(
- 0))
- mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
- batch_size=encoder_out.size(0))
- encoder_out = encoder_out * mask_shfit_chunk
- pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
- ys_pad,
- encoder_out_mask,
- ignore_id=self.ignore_id,
- mask_chunk_predictor=mask_chunk_predictor,
- target_label_length=ys_pad_lens,
- )
- predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
- encoder_out_lens)
-
- scama_mask = None
- if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
- encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
- attention_chunk_center_bias = 0
- attention_chunk_size = encoder_chunk_size
- decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
- mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
- get_mask_shift_att_chunk_decoder(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(0)
- )
- scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
- predictor_alignments=predictor_alignments,
- encoder_sequence_length=encoder_out_lens,
- chunk_size=1,
- encoder_chunk_size=encoder_chunk_size,
- attention_chunk_center_bias=attention_chunk_center_bias,
- attention_chunk_size=attention_chunk_size,
- attention_chunk_type=self.decoder_attention_chunk_type,
- step=None,
- predictor_mask_chunk_hopping=mask_chunk_predictor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
- target_length=ys_pad_lens,
- is_training=self.training,
- )
- elif self.encoder.overlap_chunk_cls is not None:
- encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
- encoder_out_lens,
- chunk_outs=None)
- # 0. sampler
- decoder_out_1st = None
- pre_loss_att = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- if self.use_1st_decoder_loss:
- sematic_embeds, decoder_out_1st, pre_loss_att = \
- self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
- ys_pad_lens, pre_acoustic_embeds, scama_mask)
- else:
- sematic_embeds, decoder_out_1st = \
- self.sampler(encoder_out, encoder_out_lens, ys_pad,
- ys_pad_lens, pre_acoustic_embeds, scama_mask)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
-
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-
- def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
-
- tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
- ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
- if self.share_embedding:
- ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
- else:
- ys_pad_embed = self.decoder.embed(ys_pad_masked)
- with torch.no_grad():
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
- pred_tokens = decoder_out.argmax(-1)
- nonpad_positions = ys_pad.ne(self.ignore_id)
- seq_lens = (nonpad_positions).sum(1)
- same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
- input_mask = torch.ones_like(nonpad_positions)
- bsz, seq_len = ys_pad.size()
- for li in range(bsz):
- target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
- if target_num > 0:
- input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
- input_mask = input_mask.eq(1)
- input_mask = input_mask.masked_fill(~nonpad_positions, False)
- input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-
- sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
- input_mask_expand_dim, 0)
- return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-
-
- def calc_predictor(self, encoder_out, encoder_out_lens):
-
- encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
- encoder_out.device)
- mask_chunk_predictor = None
- if self.encoder.overlap_chunk_cls is not None:
- mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(
- 0))
- mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
- batch_size=encoder_out.size(0))
- encoder_out = encoder_out * mask_shfit_chunk
- pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
- None,
- encoder_out_mask,
- ignore_id=self.ignore_id,
- mask_chunk_predictor=mask_chunk_predictor,
- target_label_length=None,
- )
- predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
- encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
-
- scama_mask = None
- if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
- encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
- attention_chunk_center_bias = 0
- attention_chunk_size = encoder_chunk_size
- decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
- mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
- get_mask_shift_att_chunk_decoder(None,
- device=encoder_out.device,
- batch_size=encoder_out.size(0)
- )
- scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
- predictor_alignments=predictor_alignments,
- encoder_sequence_length=encoder_out_lens,
- chunk_size=1,
- encoder_chunk_size=encoder_chunk_size,
- attention_chunk_center_bias=attention_chunk_center_bias,
- attention_chunk_size=attention_chunk_size,
- attention_chunk_type=self.decoder_attention_chunk_type,
- step=None,
- predictor_mask_chunk_hopping=mask_chunk_predictor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
- target_length=None,
- is_training=self.training,
- )
- self.scama_mask = scama_mask
-
- return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
-
- def calc_predictor_chunk(self, encoder_out, cache=None):
-
- pre_acoustic_embeds, pre_token_length = \
- self.predictor.forward_chunk(encoder_out, cache["encoder"])
- return pre_acoustic_embeds, pre_token_length
-
- def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
- )
- decoder_out = decoder_outs[0]
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out, ys_pad_lens
-
- def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
- decoder_outs = self.decoder.forward_chunk(
- encoder_out, sematic_embeds, cache["decoder"]
- )
- decoder_out = decoder_outs
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out
-
- def generate(self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- tokenizer=None,
- **kwargs,
- ):
-
- is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
- print(is_use_ctc)
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(speech, speech_lengths, **kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- # Forward Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- # predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
- pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- timestamp = []
-
- results.append((text, token, timestamp))
-
- return results
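For orientation, a minimal standalone sketch of the greedy decoding path that the generate() methods above take when no beam search is configured: take the arg-max over the decoder's token-level log-probabilities, strip the blank/sos/eos ids, and map the remaining ids to tokens. The toy score tensor and id-to-token table are assumptions for illustration only; the real code uses tokenizer.ids2tokens and tokenizer.tokens2text.

    import torch

    blank_id, sos, eos = 0, 1, 2                               # defaults used in this model
    am_scores = torch.log_softmax(torch.randn(5, 8), dim=-1)   # (pre_token_length, vocab_size)
    yseq = [sos] + am_scores.argmax(dim=-1).tolist() + [eos]   # wrap with sos/eos as the model does
    token_int = [t for t in yseq if t not in (blank_id, sos, eos)]
    id2tok = {i: f"tok{i}" for i in range(8)}                  # hypothetical vocabulary for the sketch
    print([id2tok[t] for t in token_int])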
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
index 8789025..250baad 100644
--- a/funasr/models/paraformer/search.py
+++ b/funasr/models/paraformer/search.py
@@ -9,7 +9,7 @@
import torch
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
new file mode 100644
index 0000000..1909600
--- /dev/null
+++ b/funasr/models/paraformer/template.yaml
@@ -0,0 +1,126 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Paraformer
+model_conf:
+ ctc_weight: 0.0
+ lsm_weight: 0.1
+ length_normalized_loss: true
+ predictor_weight: 1.0
+ predictor_bias: 1
+ sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+ output_size: 512
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 50
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 16
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.1
+ src_attention_dropout_rate: 0.1
+ att_layer_num: 16
+ kernel_size: 11
+ sanm_shfit: 0
+
+predictor: CifPredictorV2
+predictor_conf:
+ idim: 512
+ threshold: 1.0
+ l_order: 1
+ r_order: 1
+ tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+  batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+  max_token_length: 2048 # filter out samples whose source_token_len+target_token_len exceeds max_token_length
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+
+input_size: 560
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
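As a quick sanity check of the template above, a minimal sketch that loads the YAML with PyYAML and prints the sections a model builder would consume; this is an illustration, not the FunASR training entry point. In model.py these string names are looked up in the registry tables and instantiated with the matching *_conf dictionaries.

    import yaml

    with open("funasr/models/paraformer/template.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model"])                                          # Paraformer
    print(cfg["model_conf"]["predictor_weight"])                 # 1.0
    print(cfg["encoder"], cfg["encoder_conf"]["num_blocks"])     # SANMEncoder 50
    print(cfg["predictor"], cfg["predictor_conf"]["threshold"])  # CifPredictorV2 1.0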
diff --git a/funasr/models/paraformer_online/model.py b/funasr/models/paraformer_online/model.py
new file mode 100644
index 0000000..5cbed26
--- /dev/null
+++ b/funasr/models/paraformer_online/model.py
@@ -0,0 +1,1284 @@
+import os
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+
+from funasr.models.paraformer.search import Hypothesis
+
+# from funasr.models.model_class_factory import *
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import registry_tables
+from funasr.models.ctc.ctc import CTC
+
+class Paraformer(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ # token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[str] = None,
+ frontend_conf: Optional[Dict] = None,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ decoder: str = None,
+ decoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ predictor: str = None,
+ predictor_conf: Optional[Dict] = None,
+ ctc_weight: float = 0.5,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ # report_cer: bool = True,
+ # report_wer: bool = True,
+ # sym_space: str = "<space>",
+ # sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ # predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ use_1st_decoder_loss: bool = False,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+
+ if frontend is not None:
+ frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
+ frontend = frontend_class(**frontend_conf)
+ if specaug is not None:
+ specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+ if decoder is not None:
+ decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ if ctc_weight > 0.0:
+
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+ if predictor is not None:
+ predictor_class = registry_tables.predictor_classes.get_class(predictor.lower())
+ predictor = predictor_class(**predictor_conf)
+
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ # self.token_list = token_list.copy()
+ #
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ # self.preencoder = preencoder
+ # self.postencoder = postencoder
+ self.encoder = encoder
+ #
+ # if not hasattr(self.encoder, "interctc_use_conditioning"):
+ # self.encoder.interctc_use_conditioning = False
+ # if self.encoder.interctc_use_conditioning:
+ # self.encoder.conditioning_layer = torch.nn.Linear(
+ # vocab_size, self.encoder.output_size()
+ # )
+ #
+        # referenced at inference time by _calc_att_loss / _calc_ctc_loss
+        self.error_calculator = None
+ #
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+ #
+ # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+ self.predictor = predictor
+ self.predictor_weight = predictor_weight
+ self.predictor_bias = predictor_bias
+ self.sampling_ratio = sampling_ratio
+ self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+ # self.step_cur = 0
+ #
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.use_1st_decoder_loss = use_1st_decoder_loss
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = (text_lengths + self.predictor_bias).sum()
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
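+        # The CIF predictor integrates per-frame weights (alphas) over the encoder output
+        # and fires once the accumulated weight crosses its threshold, emitting one
+        # token-level acoustic embedding per fired peak. Returned values:
+        #   pre_acoustic_embeds: token-level acoustic embeddings
+        #   pre_token_length:    predicted token count per utterance
+        #   alphas:              per-frame CIF weights
+        #   pre_peak_index:      frame positions where the predictor fired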
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
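+        # With predictor_bias == 1 the targets are extended with <eos> via add_sos_eos,
+        # so the predictor is supervised to emit one extra token; the CIF predictor then
+        # yields token-level acoustic embeddings aligned with these targets.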
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
+
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+
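+        # Glancing-style sampling: a first decoder pass runs without gradients on the
+        # predictor's acoustic embeddings; per utterance, (#wrongly predicted tokens *
+        # sampling_ratio) randomly chosen target positions are replaced with ground-truth
+        # token embeddings, the remaining positions keep the acoustic embeddings, and the
+        # mixed sequence feeds the second, gradient-carrying decoder pass.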
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0,
+ index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
+ value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.paraformer.search import BeamSearchPara
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+        if self.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+ ctc=kwargs.get("decoding_ctc_weight", 0.0),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearchPara(
+ beam_size=kwargs.get("beam_size", 2),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
+ def generate(self,
+ data_in: list,
+ data_lengths: list=None,
+ key: list=None,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank, sos and eos symbol ids (blank is assumed to be 0)
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+ return results, meta_data
+
+
+
+class BiCifParaformer(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+
+ def _calc_pre2_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
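+        # BiCifParaformer's predictor returns an extra token-length estimate
+        # (pre_token_length2) from its second CIF branch, used later for timestamps;
+        # this loss supervises that branch with the same MAE criterion as loss_pre.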
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
+
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
+
+ return loss_pre2
+
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+
+ def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ loss_pre2 = self._calc_pre2_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ else:
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+ stats["loss_pre2"] = loss_pre2.detach().cpu()
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + self.predictor_bias).sum())
+
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def generate(self,
+ data_in: list,
+ data_lengths: list = None,
+ key: list = None,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        # BiCifParaformer: run the upsampled CIF (cif2) pass to obtain frame-level
+        # alphas/peaks used below for timestamp prediction
+
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
+ pre_token_length)
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+                # wrap the greedy output with sos/eos tokens to match the hypothesis format
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+ us_peaks[i][:encoder_out_lens[i] * 3],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0))
+
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
+
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+ "time_stamp_postprocessed": time_stamp_postprocessed,
+ "word_lists": word_lists
+ }
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+
+ return results, meta_data
+
+
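+# A minimal usage sketch for the offline `generate` above (illustrative, not part of the patch):
+# it assumes a loaded model instance, a matching tokenizer, and 16 kHz input (otherwise pass fs=...).
+# Argument names mirror the `generate` signature and the kwargs read inside it; `output_dir` is optional.
+#
+#   results, meta = model.generate(
+#       data_in=["/path/to/utt1.wav"],
+#       key=["utt1"],
+#       tokenizer=tokenizer,
+#       device="cuda",
+#       decoding_ctc_weight=0.0,   # > 0 (with a CTC head) enables the joint beam-search branch
+#       begin_time=0,              # VAD offset forwarded to timestamp prediction
+#   )
+#   # results[0]["text_postprocessed"], results[0]["time_stamp_postprocessed"]
+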
+class ParaformerOnline(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+
+ super().__init__(*args, **kwargs)
+
+ self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
+
+
+ self.scama_mask = None
+ if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+ from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+ self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+ self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+
+
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ decoding_ind = kwargs.get("decoding_ind")
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # Encoder
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ else:
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # decoder: CTC branch
+
+ if self.ctc_weight > 0.0:
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ else:
+ encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+ )
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # decoder: Attention decoder branch
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        else:
+            loss = (
+                self.ctc_weight * loss_ctc
+                + (1 - self.ctc_weight) * loss_att
+                + loss_pre * self.predictor_weight
+            )
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = (text_lengths + self.predictor_bias).sum()
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode_chunk(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, torch.tensor([encoder_out.size(1)])
+
+ def _calc_att_predictor_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ mask_chunk_predictor = None
+ if self.encoder.overlap_chunk_cls is not None:
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(
+ 0))
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+ batch_size=encoder_out.size(0))
+ encoder_out = encoder_out * mask_shfit_chunk
+ pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+ ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id,
+ mask_chunk_predictor=mask_chunk_predictor,
+ target_label_length=ys_pad_lens,
+ )
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+ encoder_out_lens)
+
+ scama_mask = None
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+ get_mask_shift_att_chunk_decoder(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(0)
+ )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=ys_pad_lens,
+ is_training=self.training,
+ )
+ elif self.encoder.overlap_chunk_cls is not None:
+ encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+ encoder_out_lens,
+ chunk_outs=None)
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ if self.use_1st_decoder_loss:
+ sematic_embeds, decoder_out_1st, pre_loss_att = \
+ self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+ ys_pad_lens, pre_acoustic_embeds, scama_mask)
+ else:
+ sematic_embeds, decoder_out_1st = \
+ self.sampler(encoder_out, encoder_out_lens, ys_pad,
+ ys_pad_lens, pre_acoustic_embeds, scama_mask)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(ys_pad.device), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
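+    # Note on the sampler above (a sketch of the idea, not a formal spec): it follows the
+    # glancing-style sampling used in Paraformer training.  A first decoder pass is run without
+    # gradients on the CIF acoustic embeddings; per utterance, the number of wrongly predicted
+    # tokens scaled by `sampling_ratio` gives target_num, and that many randomly chosen positions
+    # have their acoustic embedding replaced by the ground-truth token embedding before the
+    # second, gradient-carrying decoder pass:
+    #
+    #   target_num = int((seq_len - num_correct) * sampling_ratio)
+    #   input_mask == 1 -> keep pre_acoustic_embeds, input_mask == 0 -> use ys_pad_embed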
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ mask_chunk_predictor = None
+ if self.encoder.overlap_chunk_cls is not None:
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(
+ 0))
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+ batch_size=encoder_out.size(0))
+ encoder_out = encoder_out * mask_shfit_chunk
+ pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id,
+ mask_chunk_predictor=mask_chunk_predictor,
+ target_label_length=None,
+ )
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+ encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+
+ scama_mask = None
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+ get_mask_shift_att_chunk_decoder(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(0)
+ )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=None,
+ is_training=self.training,
+ )
+ self.scama_mask = scama_mask
+
+ return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
+ def calc_predictor_chunk(self, encoder_out, cache=None):
+
+ pre_acoustic_embeds, pre_token_length = \
+ self.predictor.forward_chunk(encoder_out, cache["encoder"])
+ return pre_acoustic_embeds, pre_token_length
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+ decoder_outs = self.decoder.forward_chunk(
+ encoder_out, sematic_embeds, cache["decoder"]
+ )
+ decoder_out = decoder_outs
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out
+
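+    # Streaming sketch (assumptions, not part of the patch): the *_chunk methods above are meant
+    # to be chained per audio chunk with a shared cache dict.  From the keys used in this file,
+    # the cache looks roughly like:
+    #
+    #   cache = {
+    #       "encoder": {...},                          # consumed by self.encoder.forward_chunk
+    #       "decoder": {"decode_fsmn": None,           # FSMN memories of the SANM decoder layers
+    #                   "opt": None,                   # cross-attention key/value cache
+    #                   "chunk_size": ...,
+    #                   "decoder_chunk_look_back": ...},
+    #   }
+    #
+    #   enc, enc_len = model.encode_chunk(feats, feats_len, cache=cache)
+    #   emb, n_token = model.calc_predictor_chunk(enc, cache=cache)
+    #   logp = model.cal_decoder_with_predictor_chunk(enc, emb, cache=cache)
+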
+ def generate(self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+        is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(speech, speech_lengths, **kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ # Forward Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+                # wrap the greedy output with sos/eos tokens to match the hypothesis format
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for hyp in nbest_hyps:
+            assert isinstance(hyp, Hypothesis), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove the blank symbol (assumed id 0) and the sos/eos symbol (assumed id 2)
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ timestamp = []
+
+ results.append((text, token, timestamp))
+
+ return results
+
diff --git a/funasr/models/paraformer_online/sanm_decoder.py b/funasr/models/paraformer_online/sanm_decoder.py
new file mode 100644
index 0000000..b1e94d7
--- /dev/null
+++ b/funasr/models/paraformer_online/sanm_decoder.py
@@ -0,0 +1,507 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(DecoderLayerSANM, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ if self_attn is not None:
+ self.norm2 = LayerNorm(size)
+ if src_attn is not None:
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, _ = self.self_attn(tgt, tgt_mask)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+ return x, tgt_mask, memory, memory_mask, cache
+
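+    # Note: unlike a standard Transformer decoder layer (self-attn -> cross-attn -> FFN), this
+    # SANM layer applies the position-wise FFN first, then the FSMN "self-attention" (a depthwise
+    # convolution memory block that needs no causal mask), and finally cross-attention over the
+    # encoder memory.  forward_one_step/forward_chunk below keep the same ordering but thread the
+    # FSMN and cross-attention caches for streaming decoding.
+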
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ if self.training:
+ cache = None
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            fsmn_cache (torch.Tensor): Cached FSMN memory of the self-attention block.
+            opt_cache (torch.Tensor): Cached cross-attention key/value memory.
+            chunk_size (int): Decoding chunk size.
+            look_back (int): Number of history chunks the cross-attention may attend to.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Updated FSMN cache.
+            torch.Tensor: Updated cross-attention cache.
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+ x = residual + x
+
+ return x, memory, fsmn_cache, opt_cache
+
+
+@register_class("decoder_classes", "ParaformerSANMDecoder")
+class ParaformerSANMDecoder(BaseTransformerDecoder):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+ """
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
+ chunk_multiply_factor: tuple = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ # pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+ self.att_layer_num = att_layer_num
+ self.num_blocks = num_blocks
+ if sanm_shfit is None:
+ sanm_shfit = (kernel_size - 1) // 2
+ self.decoders = repeat(
+ att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+ ),
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if num_blocks - att_layer_num <= 0:
+ self.decoders2 = None
+ else:
+ self.decoders2 = repeat(
+ num_blocks - att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+ ),
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.decoders3 = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.chunk_multiply_factor = chunk_multiply_factor
+
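+    # Minimal construction sketch (values are illustrative, not defaults from any recipe):
+    #
+    #   decoder = ParaformerSANMDecoder(
+    #       vocab_size=8404,
+    #       encoder_output_size=512,
+    #       attention_heads=4,
+    #       linear_units=2048,
+    #       num_blocks=16,
+    #       att_layer_num=16,   # layers with cross-attention; any remainder goes into decoders2
+    #       kernel_size=11,     # FSMN memory-block kernel size
+    #   )
+    #
+    # When att_layer_num == num_blocks, decoders2 is None and only decoders/decoders3 are built.
+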
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ if chunk_mask is not None:
+ memory_mask = memory_mask * chunk_mask
+ if tgt_mask.size(1) != memory_mask.size(1):
+ memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+ )
+ return logp.squeeze(0), state
+
+ def forward_chunk(
+ self,
+ memory: torch.Tensor,
+ tgt: torch.Tensor,
+ cache: dict = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+        Args:
+            memory: encoded memory of the current chunk, float32 (batch, maxlen_in, feat)
+            tgt: acoustic embeddings from the predictor (batch, maxlen_out, feat)
+            cache: streaming cache dict holding "decode_fsmn", "opt",
+                "chunk_size" and "decoder_chunk_look_back"
+        Returns:
+            x: decoded token scores before log-softmax (batch, maxlen_out, token)
+                if use_output_layer is True
+ """
+ x = tgt
+ if cache["decode_fsmn"] is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ fsmn_cache = [None] * cache_layer_num
+ else:
+ fsmn_cache = cache["decode_fsmn"]
+
+ if cache["opt"] is None:
+ cache_layer_num = len(self.decoders)
+ opt_cache = [None] * cache_layer_num
+ else:
+ opt_cache = cache["opt"]
+
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
+ x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
+ chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+ )
+
+ if self.num_blocks - self.att_layer_num > 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
+ x, memory, fsmn_cache=fsmn_cache[j]
+ )
+
+ for decoder in self.decoders3:
+ x, memory, _, _ = decoder.forward_chunk(
+ x, memory
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ cache["decode_fsmn"] = fsmn_cache
+ if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
+ cache["opt"] = opt_cache
+ return x
+
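+    # Note on the caches used by forward_chunk above: "decode_fsmn" holds one FSMN memory per
+    # SANM layer (decoders and decoders2), while "opt" holds the cross-attention key/value cache
+    # for the attention layers only.  The "opt" cache is persisted only when
+    # decoder_chunk_look_back is positive or -1 (unlimited look-back); with look_back == 0 the
+    # cross-attention sees the current chunk only.
+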
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+ x = self.embed(tgt)
+ if cache is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ cache = [None] * cache_layer_num
+ new_cache = []
+ # for c, decoder in zip(cache, self.decoders):
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(c_ret)
+
+ if self.num_blocks - self.att_layer_num > 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ c = cache[j]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(c_ret)
+
+ for decoder in self.decoders3:
+
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=None
+ )
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+ return y, new_cache
+
diff --git a/funasr/models/sa_asr/attention.py b/funasr/models/sa_asr/attention.py
new file mode 100644
index 0000000..2cce9ec
--- /dev/null
+++ b/funasr/models/sa_asr/attention.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+"""Multi-Head Attention layer definition."""
+
+import numpy
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
+
+
+class CosineDistanceAttention(nn.Module):
+ """ Compute Cosine Distance between spk decoder output and speaker profile
+ Args:
+ profile_path: speaker profile file path (.npy file)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, spk_decoder_out, profile, profile_lens=None):
+ """
+ Args:
+ spk_decoder_out(torch.Tensor):(B, L, D)
+ spk_profiles(torch.Tensor):(B, N, D)
+ """
+ x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
+ if profile_lens is not None:
+
+ mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+ )
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+ weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
+ else:
+ x = x[:, -1:, :, :]
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+ weights = self.softmax(weights_not_softmax) # (B, 1, N)
+ spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
+
+ return spk_embedding, weights
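+
+# Shape sketch for CosineDistanceAttention (illustrative): with a speaker-decoder output of shape
+# (B, L, D) and N profile embeddings of shape (B, N, D), the cosine similarities form a (B, L, N)
+# map; after masking invalid profiles (profile_lens) and a softmax over N, the weighted sum of the
+# profiles yields one speaker embedding per decoded token, shape (B, L, D).  At inference time
+# (profile_lens is None) only the last decoding step is scored, giving weights of shape (B, 1, N).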
diff --git a/funasr/models/sa_asr/e2e_sa_asr.py b/funasr/models/sa_asr/e2e_sa_asr.py
index e0cb69a..3bfecfc 100644
--- a/funasr/models/sa_asr/e2e_sa_asr.py
+++ b/funasr/models/sa_asr/e2e_sa_asr.py
@@ -20,13 +20,13 @@
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
diff --git a/funasr/models/transformer/transformer_decoder.py b/funasr/models/sa_asr/transformer_decoder.py
similarity index 65%
rename from funasr/models/transformer/transformer_decoder.py
rename to funasr/models/sa_asr/transformer_decoder.py
index b2bea68..3319212 100644
--- a/funasr/models/transformer/transformer_decoder.py
+++ b/funasr/models/sa_asr/transformer_decoder.py
@@ -10,9 +10,9 @@
import torch
from torch import nn
-from funasr.models.decoder.abs_decoder import AbsDecoder
+
from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.transformer.attention import CosineDistanceAttention
+from funasr.models.sa_asr.attention import CosineDistanceAttention
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.embedding import PositionalEncoding
@@ -24,9 +24,10 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+from funasr.utils.register import register_class, registry_tables
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -150,7 +151,7 @@
return x, tgt_mask, memory, memory_mask
-class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
@@ -352,7 +353,7 @@
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
-
+@register_class("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -401,6 +402,7 @@
)
+@register_class("decoder_classes", "ParaformerDecoderSAN")
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -514,7 +516,7 @@
else:
return x, olens
-
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -575,7 +577,7 @@
),
)
-
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -637,6 +639,7 @@
)
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -697,7 +700,7 @@
),
)
-
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
@@ -757,426 +760,3 @@
concat_after,
),
)
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- ):
- super().__init__()
- attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_asr_output_layer:
- self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.asr_output_layer = None
-
- if use_spk_output_layer:
- self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
- else:
- self.spk_output_layer = None
-
- self.cos_distance_att = CosineDistanceAttention()
-
- self.decoder1 = None
- self.decoder2 = None
- self.decoder3 = None
- self.decoder4 = None
-
- def forward(
- self,
- asr_hs_pad: torch.Tensor,
- spk_hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- profile: torch.Tensor,
- profile_lens: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-
- tgt = ys_in_pad
- # tgt_mask: (B, 1, L)
- tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
- # m: (1, L, L)
- m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
- # tgt_mask: (B, L, L)
- tgt_mask = tgt_mask & m
-
- asr_memory = asr_hs_pad
- spk_memory = spk_hs_pad
- memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
- # Spk decoder
- x = self.embed(tgt)
-
- x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, memory_mask
- )
- x, tgt_mask, spk_memory, memory_mask = self.decoder2(
- x, tgt_mask, spk_memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, profile_lens)
- # Asr decoder
- x, tgt_mask, asr_memory, memory_mask = self.decoder3(
- z, tgt_mask, asr_memory, memory_mask, dn
- )
- x, tgt_mask, asr_memory, memory_mask = self.decoder4(
- x, tgt_mask, asr_memory, memory_mask
- )
-
- if self.normalize_before:
- x = self.after_norm(x)
- if self.asr_output_layer is not None:
- x = self.asr_output_layer(x)
-
- olens = tgt_mask.sum(1)
- return x, weights, olens
-
-
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- asr_memory: torch.Tensor,
- spk_memory: torch.Tensor,
- profile: torch.Tensor,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
- x = self.embed(tgt)
-
- if cache is None:
- cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
- new_cache = []
- x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
- x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
- )
- new_cache.append(x)
- for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
- x, tgt_mask, spk_memory, _ = decoder(
- x, tgt_mask, spk_memory, None, cache=c
- )
- new_cache.append(x)
- if self.normalize_before:
- x = self.after_norm(x)
- else:
- x = x
- if self.spk_output_layer is not None:
- x = self.spk_output_layer(x)
- dn, weights = self.cos_distance_att(x, profile, None)
-
- x, tgt_mask, asr_memory, _ = self.decoder3(
- z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
- )
- new_cache.append(x)
-
- for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
- x, tgt_mask, asr_memory, _ = decoder(
- x, tgt_mask, asr_memory, None, cache=c
- )
- new_cache.append(x)
-
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.asr_output_layer is not None:
- y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
- return y, weights, new_cache
-
- def score(self, ys, state, asr_enc, spk_enc, profile):
- """Score."""
- ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
- logp, weights, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
- )
- return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- spker_embedding_dim: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- asr_num_blocks: int = 6,
- spk_num_blocks: int = 3,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_asr_output_layer: bool = True,
- use_spk_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- spker_embedding_dim=spker_embedding_dim,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_asr_output_layer=use_asr_output_layer,
- use_spk_output_layer=use_spk_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
-
- attention_dim = encoder_output_size
-
- self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder2 = repeat(
- spk_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
-
- self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
- attention_dim,
- spker_embedding_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
- self.decoder4 = repeat(
- asr_num_blocks - 1,
- lambda lnum: DecoderLayer(
- attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
- self.size = size
- self.self_attn = self_attn
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
- # compute only the last frame query keeping dim: max_time_out -> 1
- assert cache.shape == (
- tgt.shape[0],
- tgt.shape[1] - 1,
- self.size,
- ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
- x = residual + self.concat_linear1(tgt_concat)
- else:
- x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
- if not self.normalize_before:
- x = self.norm1(x)
- z = x
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
- if self.concat_after:
- x_concat = torch.cat(
- (x, skip), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(skip)
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-
- def __init__(
- self,
- size,
- d_size,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
- self.size = size
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- self.spk_linear = nn.Linear(d_size, size, bias=False)
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
-
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- x = tgt_q
- if self.normalize_before:
- x = self.norm2(x)
- if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
- if not self.normalize_before:
- x = self.norm2(x)
- residual = x
-
- if dn!=None:
- x = x + self.spk_linear(dn)
- if self.normalize_before:
- x = self.norm3(x)
-
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm3(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, memory, memory_mask
\ No newline at end of file
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
new file mode 100644
index 0000000..f48617c
--- /dev/null
+++ b/funasr/models/sanm/attention.py
@@ -0,0 +1,641 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
+
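+# Minimal usage sketch for MultiHeadedAttention above (shapes are illustrative):
+#
+#   mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
+#   q = torch.randn(2, 10, 256)                    # (batch, time1, size)
+#   kv = torch.randn(2, 20, 256)                   # (batch, time2, size)
+#   mask = torch.ones(2, 1, 20, dtype=torch.bool)  # (batch, 1, time2)
+#   out = mha(q, kv, kv, mask)                     # (2, 10, 256); softmax(QK^T / sqrt(d_k)) V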
+
+class MultiHeadedAttentionSANM(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ # self.linear_q = nn.Linear(n_feat, n_feat)
+ # self.linear_k = nn.Linear(n_feat, n_feat)
+ # self.linear_v = nn.Linear(n_feat, n_feat)
+ if lora_list is not None:
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+ if lora_qkv_list == [False, False, False]:
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+ # padding
+ left_padding = (kernel_size - 1) // 2
+ if sanm_shfit > 0:
+ left_padding = left_padding + sanm_shfit
+ right_padding = kernel_size - 1 - left_padding
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+
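+    # Padding arithmetic for the FSMN memory block (worked example, assuming kernel_size = 11):
+    # with sanm_shfit = 0, left = right = 5, so every frame sees 5 past and 5 future frames.
+    # A positive sanm_shfit moves that many taps from the future to the past side
+    # (left = 5 + shift, right = 5 - shift), reducing the look-ahead of the memory block.
+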
+ def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+ b, t, d = inputs.size()
+ if mask is not None:
+ mask = torch.reshape(mask, (b, -1, 1))
+ if mask_shfit_chunk is not None:
+ mask = mask * mask_shfit_chunk
+ inputs = inputs * mask
+
+ x = inputs.transpose(1, 2)
+ x = self.pad_fn(x)
+ x = self.fsmn_block(x)
+ x = x.transpose(1, 2)
+ x += inputs
+ x = self.dropout(x)
+ if mask is not None:
+ x = x * mask
+ return x
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Un-split value tensor (#batch, time, n_feat) for the FSMN block.
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time).
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN block.
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for the attention.
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs + fsmn_memory
+
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+        # look_back == -1 means unlimited look-back; both branches need chunk_size.
+        if chunk_size is not None and (look_back > 0 or look_back == -1):
+ if cache is not None:
+ k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+ v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+
+ cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+ cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+ if look_back != -1:
+ cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+ "v": v_h[:, :, :-(chunk_size[2]), :]}
+ cache = cache_tmp
+ fsmn_memory = self.forward_fsmn(v, None)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, None)
+ return att_outs + fsmn_memory, cache
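+
+
+# Illustrative usage of MultiHeadedAttentionSANM (a sketch, not part of the
+# original patch; shapes and hyper-parameters are assumptions):
+#
+#   attn = MultiHeadedAttentionSANM(n_head=4, in_feat=256, n_feat=256,
+#                                   dropout_rate=0.1, kernel_size=11)
+#   x = torch.randn(2, 50, 256)      # (batch, time, in_feat)
+#   mask = torch.ones(2, 1, 50)      # (batch, 1, time)
+#   out = attn(x, mask)              # (batch, time, n_feat): attention + FSMN memory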
+
+
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttentionSANMDecoder, self).__init__()
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ self.fsmn_block = nn.Conv1d(n_feat, n_feat,
+ kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+ left_padding = (kernel_size - 1) // 2
+ if sanm_shfit > 0:
+ left_padding = left_padding + sanm_shfit
+ right_padding = kernel_size - 1 - left_padding
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+ self.kernel_size = kernel_size
+
+ def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
+        """Apply the FSMN memory block.
+
+        Args:
+            inputs (torch.Tensor): Input tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor (#batch, time1, 1).
+            cache (torch.Tensor): Cached padded inputs for streaming decoding.
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask for streaming.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, size).
+            torch.Tensor: Updated cache.
+        """
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+ x = inputs.transpose(1, 2)
+ b, d, t = x.size()
+        if cache is None:
+            # First call: pad the full sequence and remember it as the cache
+            # (only at inference time).
+            x = self.pad_fn(x)
+            if not self.training:
+                cache = x
+        else:
+            # Streaming call: shift the cache by one frame, append the new
+            # frame(s) and keep exactly kernel_size + t - 1 frames of context.
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size + t - 1):]
+            cache = x
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+ if x.size(1) != inputs.size(1):
+ inputs = inputs[:, -1, :]
+
+ x = x + inputs
+ x = self.dropout(x)
+ if mask is not None:
+ x = x * mask
+ return x, cache
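+
+
+# Sketch of how the decoder FSMN block is driven in streaming decoding
+# (illustrative only; shapes are assumptions):
+#
+#   fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.1, kernel_size=11)
+#   fsmn.eval()                                       # cache is only kept in eval mode
+#   step = torch.randn(1, 1, 256)                     # one new decoder step
+#   out, cache = fsmn(step, mask=None)                # first step builds the cache
+#   out, cache = fsmn(out, mask=None, cache=cache)    # later steps reuse it
+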
+
+class MultiHeadedAttentionCrossAtt(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttentionCrossAtt, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ if lora_list is not None:
+ if "q" in lora_list:
+ self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ lora_kv_list = ["k" in lora_list, "v" in lora_list]
+ if lora_kv_list == [False, False]:
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x, memory):
+ """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Query-side input tensor (#batch, time1, n_feat).
+            memory (torch.Tensor): Encoder memory (#batch, time2, feat).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+
+ # print("in forward_qkv, x", x.size())
+ b = x.size(0)
+ q = self.linear_q(x)
+ q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+
+ k_v = self.linear_k_v(memory)
+ k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
+ k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, memory, memory_mask):
+ """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Query-side input tensor (#batch, time1, n_feat).
+            memory (torch.Tensor): Encoder memory (#batch, time2, feat).
+            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h = self.forward_qkv(x, memory)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ return self.forward_attention(v_h, scores, memory_mask)
+
+ def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h = self.forward_qkv(x, memory)
+ if chunk_size is not None and look_back > 0:
+ if cache is not None:
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+ cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+ "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+ cache = cache_tmp
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ return self.forward_attention(v_h, scores, None), cache
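+
+
+# Illustrative cross-attention call (a sketch, not part of the original patch;
+# shapes are assumptions):
+#
+#   cross = MultiHeadedAttentionCrossAtt(n_head=4, n_feat=256, dropout_rate=0.1)
+#   queries = torch.randn(2, 8, 256)     # decoder states (batch, time1, n_feat)
+#   memory  = torch.randn(2, 30, 256)    # encoder output (batch, time2, n_feat)
+#   memory_mask = torch.ones(2, 1, 30)   # (batch, 1, time2)
+#   out = cross(queries, memory, memory_mask)   # (batch, time1, n_feat)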
+
+
+class MultiHeadSelfAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadSelfAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Un-split value tensor (#batch, time, n_feat).
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time).
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for the attention.
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs
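+
+
+# Illustrative self-attention call (a sketch, not part of the original patch;
+# shapes are assumptions):
+#
+#   attn = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.1)
+#   x = torch.randn(2, 50, 256)
+#   mask = torch.ones(2, 1, 50)
+#   out = attn(x, mask)              # (batch, time, n_feat)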
+
+
+
diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py
new file mode 100644
index 0000000..64033ad
--- /dev/null
+++ b/funasr/models/sanm/decoder.py
@@ -0,0 +1,474 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(DecoderLayerSANM, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ if self_attn is not None:
+ self.norm2 = LayerNorm(size)
+ if src_attn is not None:
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, _ = self.self_attn(tgt, tgt_mask)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ if self.training:
+ cache = None
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+ x = residual + x
+
+ return x, memory, fsmn_cache, opt_cache
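+
+    # Note: forward_chunk keeps two caches per layer, `fsmn_cache` for the
+    # convolutional memory block and `opt_cache` (cached k/v) for the
+    # cross-attention over streamed encoder chunks.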
+
+
+@register_class("decoder_classes", "FsmnDecoder")
+class FsmnDecoder(BaseTransformerDecoder):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01712
+
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = None,
+ concat_embeds: bool = False,
+ attention_dim: int = None,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ embed_tensor_name_prefix_tf: str = None,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+ if attention_dim is None:
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+ self.att_layer_num = att_layer_num
+ self.num_blocks = num_blocks
+ if sanm_shfit is None:
+ sanm_shfit = (kernel_size - 1) // 2
+ self.decoders = repeat(
+ att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ MultiHeadedAttentionCrossAtt(
+ attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+ ),
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if num_blocks - att_layer_num <= 0:
+ self.decoders2 = None
+ else:
+ self.decoders2 = repeat(
+ num_blocks - att_layer_num,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ MultiHeadedAttentionSANMDecoder(
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ ),
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.decoders3 = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if concat_embeds:
+ self.embed_concat_ffn = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim + encoder_output_size,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+ adim=attention_dim),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ else:
+ self.embed_concat_ffn = None
+ self.concat_embeds = concat_embeds
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ if chunk_mask is not None:
+ memory_mask = memory_mask * chunk_mask
+ if tgt_mask.size(1) != memory_mask.size(1):
+ memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
+ x = self.embed(tgt)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
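+
+    # Minimal offline usage sketch (illustrative only; all sizes below are
+    # assumptions, not values from the original patch):
+    #
+    #   decoder = FsmnDecoder(vocab_size=100, encoder_output_size=256,
+    #                         attention_heads=4, linear_units=1024,
+    #                         num_blocks=3, att_layer_num=2)
+    #   hs = torch.randn(2, 30, 256); hlens = torch.tensor([30, 25])
+    #   ys = torch.randint(0, 100, (2, 8)); ylens = torch.tensor([8, 6])
+    #   logits, olens = decoder(hs, hlens, ys, ylens)   # logits: (2, 8, 100)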
+
+ def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+ """Score."""
+ ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=state
+ )
+ return logp.squeeze(0), state
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+ """
+
+ x = tgt[:, -1:]
+ tgt_mask = None
+ x = self.embed(x)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+ if cache is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ cache = [None] * cache_layer_num
+ new_cache = []
+ # for c, decoder in zip(cache, self.decoders):
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ if self.num_blocks - self.att_layer_num >= 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ c = cache[j]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ for decoder in self.decoders3:
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=None
+ )
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = self.output_layer(y)
+ y = torch.log_softmax(y, dim=-1)
+
+ return y, new_cache
+
+
\ No newline at end of file
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
new file mode 100644
index 0000000..8e159e2
--- /dev/null
+++ b/funasr/models/sanm/encoder.py
@@ -0,0 +1,454 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class
+
+class EncoderLayerSANM(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayerSANM, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(in_size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.in_size = in_size
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the FSMN block.
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for the attention.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+            The cache and chunk masks are passed through unchanged.
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ else:
+ x = stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input chunk (#batch, chunk_time, size).
+            cache (dict): Cached keys/values of the self-attention from previous chunks.
+            chunk_size (list): Chunk configuration (stride/look-ahead sizes).
+            look_back (int): Number of chunks to look back.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, chunk_time, size).
+            dict: Updated cache.
+
+ """
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.in_size == self.size:
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+ x = residual + attn
+ else:
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.feed_forward(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, cache
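+
+# Note: the first SANM encoder layer is constructed with in_size != size (raw
+# feature dim -> model dim); in that case the attention output replaces the
+# input instead of being added to it, so that layer carries no residual path.
+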
+
+@register_class("encoder_classes", "SANMEncoder")
+class SANMEncoder(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ San-m: Memory equipped self-attention for end-to-end speech recognition
+ https://arxiv.org/abs/2006.01713
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size : int = 11,
+ sanm_shfit : int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
+ selfattention_layer_type: str = "sanm",
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ SinusoidalPositionEncoder(),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+
+ elif selfattention_layer_type == "sanm":
+ encoder_selfattn_layer = MultiHeadedAttentionSANM
+ encoder_selfattn_layer_args0 = (
+ attention_heads,
+ input_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
+ )
+
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
+ )
+ self.encoders0 = repeat(
+ 1,
+ lambda lnum: EncoderLayerSANM(
+ input_size,
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.encoders = repeat(
+ num_blocks-1,
+ lambda lnum: EncoderLayerSANM(
+ output_size,
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ self.dropout = nn.Dropout(dropout_rate)
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ xs_pad = xs_pad * self.output_size()**0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
+ ):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ encoder_outs = self.encoders0(xs_pad, masks)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, masks)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, masks)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
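+
+    # Minimal usage sketch (illustrative only; sizes below are assumptions):
+    #
+    #   enc = SANMEncoder(input_size=80, output_size=256, attention_heads=4,
+    #                     linear_units=1024, num_blocks=4, input_layer="pe")
+    #   feats = torch.randn(2, 100, 80)            # (batch, frames, feat dim)
+    #   feat_lens = torch.tensor([100, 80])
+    #   enc_out, enc_lens, _ = enc(feats, feat_lens)   # (2, 100, 256), (2,)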
+
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ctc: CTC = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ else:
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
+ encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), None, None
+ return xs_pad, ilens, None
+
diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py
new file mode 100644
index 0000000..e01394c
--- /dev/null
+++ b/funasr/models/sanm/model.py
@@ -0,0 +1,18 @@
+import logging
+
+import torch
+
+from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "SANM")
+class SANM(Transformer):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+
+ super().__init__(*args, **kwargs)
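+
+# SANM reuses the Transformer training/recognition logic unchanged; the
+# SANM-specific behaviour comes from the encoder/decoder classes chosen in the
+# model config (typically a SANMEncoder paired with an FSMN-based decoder).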
diff --git a/funasr/models/sanm/positionwise_feed_forward.py b/funasr/models/sanm/positionwise_feed_forward.py
new file mode 100644
index 0000000..7125854
--- /dev/null
+++ b/funasr/models/sanm/positionwise_feed_forward.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+from funasr.models.transformer.layer_norm import LayerNorm
+
+
+
+class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+    Args:
+        idim (int): Input dimension.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
+        adim (int): Optional output dimension; defaults to idim.
+        activation (torch.nn.Module): Activation module (default: torch.nn.ReLU()).
+
+ """
+
+ def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
+ """Construct an PositionwiseFeedForward object."""
+ super(PositionwiseFeedForwardDecoderSANM, self).__init__()
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
+ self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.activation = activation
+ self.norm = LayerNorm(hidden_units)
+
+ def forward(self, x):
+ """Forward function."""
+ return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
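+
+# Note: unlike the standard Transformer FFN w_2(dropout(act(w_1(x)))), this
+# decoder variant applies a LayerNorm before the second projection and makes
+# w_2 bias-free; the optional `adim` lets the layer project to a different
+# output width, which FsmnDecoder uses for its embedding concat FFN.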
diff --git a/funasr/models/sanm/sanm_decoder.py b/funasr/models/sanm/sanm_decoder.py
deleted file mode 100644
index cb38c12..0000000
--- a/funasr/models/sanm/sanm_decoder.py
+++ /dev/null
@@ -1,1541 +0,0 @@
-from typing import List
-from typing import Tuple
-import logging
-import torch
-import torch.nn as nn
-import numpy as np
-
-from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-
-
-class DecoderLayerSANM(nn.Module):
- """Single decoder layer module.
-
- Args:
- size (int): Input dimension.
- self_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` instance can be used as the argument.
- src_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` instance can be used as the argument.
- feed_forward (torch.nn.Module): Feed-forward module instance.
- `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
- can be used as the argument.
- dropout_rate (float): Dropout rate.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied. i.e. x -> x + att(x)
-
-
- """
-
- def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an DecoderLayer object."""
- super(DecoderLayerSANM, self).__init__()
- self.size = size
- self.self_attn = self_attn
- self.src_attn = src_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- if self_attn is not None:
- self.norm2 = LayerNorm(size)
- if src_attn is not None:
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- """Compute decoded features.
-
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
-
- Returns:
- torch.Tensor: Output tensor(#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
- """
- # tgt = self.dropout(tgt)
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
-
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- x, _ = self.self_attn(tgt, tgt_mask)
- x = residual + self.dropout(x)
-
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
-
- x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
- return x, tgt_mask, memory, memory_mask, cache
-
- def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- """Compute decoded features.
-
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
-
- Returns:
- torch.Tensor: Output tensor(#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
- """
- # tgt = self.dropout(tgt)
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
-
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- if self.training:
- cache = None
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
- x = residual + self.dropout(x)
-
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
-
- x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
-
- return x, tgt_mask, memory, memory_mask, cache
-
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
- """Compute decoded features.
-
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
-
- Returns:
- torch.Tensor: Output tensor(#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
- """
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
- tgt = self.feed_forward(tgt)
-
- x = tgt
- if self.self_attn:
- if self.normalize_before:
- tgt = self.norm2(tgt)
- x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
- x = residual + self.dropout(x)
-
- if self.src_attn is not None:
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
-
- x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
- x = residual + x
-
- return x, memory, fsmn_cache, opt_cache
-
-
-class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
-
- """
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = None,
- concat_embeds: bool = False,
- attention_dim: int = None,
- tf2torch_tensor_name_prefix_torch: str = "decoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
- embed_tensor_name_prefix_tf: str = None,
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_output_layer=use_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
- if attention_dim is None:
- attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_output_layer:
- self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.output_layer = None
-
- self.att_layer_num = att_layer_num
- self.num_blocks = num_blocks
- if sanm_shfit is None:
- sanm_shfit = (kernel_size - 1) // 2
- self.decoders = repeat(
- att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if num_blocks - att_layer_num <= 0:
- self.decoders2 = None
- else:
- self.decoders2 = repeat(
- num_blocks - att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.decoders3 = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if concat_embeds:
- self.embed_concat_ffn = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim + encoder_output_size,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
- adim=attention_dim),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- else:
- self.embed_concat_ffn = None
- self.concat_embeds = concat_embeds
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
- self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- chunk_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
- if chunk_mask is not None:
- memory_mask = memory_mask * chunk_mask
- if tgt_mask.size(1) != memory_mask.size(1):
- memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
- x = self.embed(tgt)
-
- if pre_acoustic_embeds is not None and self.concat_embeds:
- x = torch.cat((x, pre_acoustic_embeds), dim=-1)
- x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
- if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
-
- olens = tgt_mask.sum(1)
- return x, olens
-
- def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
- """Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
- cache=state
- )
- return logp.squeeze(0), state
-
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- memory_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- """Forward one step.
-
- Args:
- tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (including 1.2)
- memory: encoded memory, float32 (batch, maxlen_in, feat)
- cache: cached output list of (batch, max_time_out-1, size)
- Returns:
-            y, cache: NN output value and cache per `self.decoders`.
-            `y.shape` is (batch, token) for the last decoding position.
- """
-
- x = tgt[:, -1:]
- tgt_mask = None
- x = self.embed(x)
-
- if pre_acoustic_embeds is not None and self.concat_embeds:
- x = torch.cat((x, pre_acoustic_embeds), dim=-1)
- x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
- if cache is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- cache = [None] * cache_layer_num
- new_cache = []
- # for c, decoder in zip(cache, self.decoders):
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, memory_mask, cache=c
- )
- new_cache.append(c_ret)
-
- if self.num_blocks - self.att_layer_num >= 1:
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, memory_mask, cache=c
- )
- new_cache.append(c_ret)
-
- for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=None
- )
-
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.output_layer is not None:
- y = self.output_layer(y)
- y = torch.log_softmax(y, dim=-1)
-
- return y, new_cache
-
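A minimal sketch of how the step-wise API above can be driven for greedy search; `decoder` stands for an instance of the class above, `enc` is one utterance's encoder output, and `sos_id`, `eos_id` and `max_len` are placeholder values, not taken from this file.

    import torch

    def greedy_decode(decoder, enc, sos_id, eos_id, max_len=50):
        # enc: (T, D) encoder memory for a single utterance
        ys = torch.tensor([sos_id], dtype=torch.long, device=enc.device)
        state = None  # per-layer FSMN caches, created lazily inside forward_one_step
        for _ in range(max_len):
            logp, state = decoder.score(ys, state, enc)  # (vocab,), updated cache list
            next_id = int(logp.argmax(-1))
            ys = torch.cat([ys, ys.new_tensor([next_id])])
            if next_id == eos_id:
                break
        return ys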
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
- map_dict_local = {
-
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
-
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
- "{}/w_embs".format(embed_tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
-
- }
- return map_dict_local
-
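A worked example of how a single entry of the map above is applied during conversion; the tensor here is random and only the shapes, taken from the inline comments, matter.

    import numpy as np
    import torch

    # Entry for ".feed_forward.w_1.weight": TF conv1d kernel with shape (1, 256, 1024)
    entry = {"squeeze": 0, "transpose": (1, 0)}
    data_tf = np.random.randn(1, 256, 1024).astype(np.float32)
    if entry["squeeze"] is not None:
        data_tf = np.squeeze(data_tf, axis=entry["squeeze"])  # -> (256, 1024)
    if entry["transpose"] is not None:
        data_tf = np.transpose(data_tf, entry["transpose"])   # -> (1024, 256)
    weight = torch.from_numpy(data_tf)  # matches the torch Linear weight shape (1024, 256)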
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
-
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed" or names[1] == "output_layer":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
-
-class ParaformerSANMDecoder(BaseTransformerDecoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2006.01713
- """
- def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = 0,
- lora_list: List[str] = None,
- lora_rank: int = 8,
- lora_alpha: int = 16,
- lora_dropout: float = 0.1,
- chunk_multiply_factor: tuple = (1,),
- tf2torch_tensor_name_prefix_torch: str = "decoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
- ):
- super().__init__(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- dropout_rate=dropout_rate,
- positional_dropout_rate=positional_dropout_rate,
- input_layer=input_layer,
- use_output_layer=use_output_layer,
- pos_enc_class=pos_enc_class,
- normalize_before=normalize_before,
- )
-
- attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- # pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
- else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
- self.normalize_before = normalize_before
- if self.normalize_before:
- self.after_norm = LayerNorm(attention_dim)
- if use_output_layer:
- self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
- else:
- self.output_layer = None
-
- self.att_layer_num = att_layer_num
- self.num_blocks = num_blocks
- if sanm_shfit is None:
- sanm_shfit = (kernel_size - 1) // 2
- self.decoders = repeat(
- att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if num_blocks - att_layer_num <= 0:
- self.decoders2 = None
- else:
- self.decoders2 = repeat(
- num_blocks - att_layer_num,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
- ),
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.decoders3 = repeat(
- 1,
- lambda lnum: DecoderLayerSANM(
- attention_dim,
- None,
- None,
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
- self.chunk_multiply_factor = chunk_multiply_factor
-
- def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- chunk_mask: torch.Tensor = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- tgt = ys_in_pad
- tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
- if chunk_mask is not None:
- memory_mask = memory_mask * chunk_mask
- if tgt_mask.size(1) != memory_mask.size(1):
- memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
- x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
- if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
-
- olens = tgt_mask.sum(1)
- return x, olens
-
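A shape-level usage sketch for forward(); `decoder` stands for an instance of this class and all dimensions are placeholders. Note that ys_in_pad is consumed directly as x (there is no embedding lookup in forward), so in practice it is a float sequence such as the predictor's acoustic embeddings.

    import torch

    def run_decoder_once(decoder, d=256):
        hs_pad = torch.randn(2, 100, d)      # encoder memory (B, T_enc, D)
        hlens = torch.tensor([100, 80])
        ys_in_pad = torch.randn(2, 20, d)    # acoustic embeddings, not token ids
        ys_in_lens = torch.tensor([20, 15])
        logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
        return logits, olens                 # (B, T_out, vocab), (B,)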
- def score(self, ys, state, x):
- """Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
- return logp.squeeze(0), state
-
- def forward_chunk(
- self,
- memory: torch.Tensor,
- tgt: torch.Tensor,
- cache: dict = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- x = tgt
- if cache["decode_fsmn"] is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- fsmn_cache = [None] * cache_layer_num
- else:
- fsmn_cache = cache["decode_fsmn"]
-
- if cache["opt"] is None:
- cache_layer_num = len(self.decoders)
- opt_cache = [None] * cache_layer_num
- else:
- opt_cache = cache["opt"]
-
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
- chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
- )
-
-        if self.num_blocks - self.att_layer_num >= 1:  # consistent with the condition used to build self.decoders2
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[j]
- )
-
- for decoder in self.decoders3:
- x, memory, _, _ = decoder.forward_chunk(
- x, memory
- )
- if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
-
- cache["decode_fsmn"] = fsmn_cache
- if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
- cache["opt"] = opt_cache
- return x
-
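A possible initial cache for forward_chunk(); the keys below are the ones read inside the method, while the chunk-size and look-back values are illustrative only.

    def init_decoder_cache(chunk_size=(5, 10, 5), decoder_chunk_look_back=1):
        return {
            "decode_fsmn": None,             # per-layer FSMN memories, filled on the first call
            "opt": None,                     # per-layer cross-attention caches
            "chunk_size": chunk_size,
            "decoder_chunk_look_back": decoder_chunk_look_back,
        }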
- def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- cache: List[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- """Forward one step.
-
- Args:
- tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (including 1.2)
- memory: encoded memory, float32 (batch, maxlen_in, feat)
- cache: cached output list of (batch, max_time_out-1, size)
- Returns:
-            y, cache: NN output value and cache per `self.decoders`.
-            `y.shape` is (batch, token) for the last decoding position.
- """
- x = self.embed(tgt)
- if cache is None:
- cache_layer_num = len(self.decoders)
- if self.decoders2 is not None:
- cache_layer_num += len(self.decoders2)
- cache = [None] * cache_layer_num
- new_cache = []
- # for c, decoder in zip(cache, self.decoders):
- for i in range(self.att_layer_num):
- decoder = self.decoders[i]
- c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=c
- )
- new_cache.append(c_ret)
-
-        if self.num_blocks - self.att_layer_num >= 1:  # consistent with the condition used to build self.decoders2
- for i in range(self.num_blocks - self.att_layer_num):
- j = i + self.att_layer_num
- decoder = self.decoders2[i]
- c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=c
- )
- new_cache.append(c_ret)
-
- for decoder in self.decoders3:
-
- x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
- x, tgt_mask, memory, None, cache=None
- )
-
- if self.normalize_before:
- y = self.after_norm(x[:, -1])
- else:
- y = x[:, -1]
- if self.output_layer is not None:
- y = torch.log_softmax(self.output_layer(y), dim=-1)
-
- return y, new_cache
-
- def gen_tf2torch_map_dict(self):
-
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
-
- ## decoder
- # ffn
- "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # fsmn
- "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
- tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
- tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- # src att
- "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # embed_concat_ffn
- "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
-
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- # in embed
- "{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (4235,256),(4235,256)
-
- # out layer
- "{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
- "squeeze": [None, None],
- "transpose": [(1, 0), None],
- }, # (4235,256),(256,4235)
- "{}.output_layer.bias".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
- "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
- "squeeze": [None, None],
- "transpose": [None, None],
- }, # (4235,),(4235,)
-
- }
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- decoder_layeridx_sets = set()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "decoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders2":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- name_q = name_q.replace("decoders2", "decoders")
- layeridx_bias = len(decoder_layeridx_sets)
-
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "decoders3":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed" or names[1] == "output_layer":
- name_tf = map_dict[name]["name"]
- if isinstance(name_tf, list):
- idx_list = 0
- if name_tf[idx_list] in var_dict_tf.keys():
- pass
- else:
- idx_list = 1
- data_tf = var_dict_tf[name_tf[idx_list]]
- if map_dict[name]["squeeze"][idx_list] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
- if map_dict[name]["transpose"][idx_list] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
- name_tf[idx_list],
- var_dict_tf[name_tf[
- idx_list]].shape))
-
- else:
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "embed_concat_ffn":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- layeridx_bias = 0
- layeridx += layeridx_bias
- if "decoders." in name:
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
diff --git a/funasr/models/sanm/sanm_encoder.py b/funasr/models/sanm/sanm_encoder.py
deleted file mode 100644
index 83edbe7..0000000
--- a/funasr/models/sanm/sanm_encoder.py
+++ /dev/null
@@ -1,1293 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.scama.chunk_utilis import overlap_chunk
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
-from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
-
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-class EncoderLayerSANM(nn.Module):
- def __init__(
- self,
- in_size,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayerSANM, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(in_size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.in_size = in_size
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
- self.dropout_rate = dropout_rate
-
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- """Compute encoded features.
-
- Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-            mask_shfit_chunk (torch.Tensor): Optional chunk mask used inside SANM self-attention.
-            mask_att_chunk_encoder (torch.Tensor): Optional chunk attention mask over the encoder.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-            The cache and the two chunk masks are passed through unchanged.
-
- """
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- return x, mask
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- else:
- x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
- return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
- def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
- """Compute encoded features.
-
- Args:
- x_input (torch.Tensor): Input tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
-
- """
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if self.in_size == self.size:
- attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
- x = residual + attn
- else:
- x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
-
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + self.feed_forward(x)
- if not self.normalize_before:
- x = self.norm2(x)
-
- return x, cache
-
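The residual logic in forward() above implements stochastic depth; a stripped-down sketch of that rule, with f standing for the attention or feed-forward sub-block, is:

    import torch

    def stochastic_residual(x, f, p, training):
        # With probability p the sub-block is skipped at training time; otherwise its
        # output is scaled by 1/(1-p) so the expected contribution matches inference.
        if training and p > 0:
            if torch.rand(1).item() < p:
                return x
            return x + f(x) / (1.0 - p)
        return x + f(x)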
-
-class SANMEncoder(AbsEncoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- San-m: Memory equipped self-attention for end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- lora_list: List[str] = None,
- lora_rank: int = 8,
- lora_alpha: int = 16,
- lora_dropout: float = 0.1,
- selfattention_layer_type: str = "sanm",
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
- ):
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- elif input_layer == "pe_online":
- self.embed = StreamSinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- lora_list,
- lora_rank,
- lora_alpha,
- lora_dropout,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- lora_list,
- lora_rank,
- lora_alpha,
- lora_dropout,
- )
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad = xs_pad * self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- encoder_outs = self.encoders0(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
- def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
- if len(cache) == 0:
- return feats
- cache["feats"] = to_device(cache["feats"], device=feats.device)
- overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
- return overlap_feats
-
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- ctc: CTC = None,
- ):
- xs_pad *= self.output_size() ** 0.5
- if self.embed is None:
- xs_pad = xs_pad
- else:
- xs_pad = self.embed(xs_pad, cache)
- if cache["tail_chunk"]:
- xs_pad = to_device(cache["feats"], device=xs_pad.device)
- else:
- xs_pad = self._add_overlap_chunk(xs_pad, cache)
- encoder_outs = self.encoders0(xs_pad, None, None, None, None)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, None, None, None, None)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, None, None, None, None)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), None, None
- return xs_pad, ilens, None
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
-
-class SANMEncoderChunkOpt(AbsEncoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
- https://arxiv.org/abs/2006.01713
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- selfattention_layer_type: str = "sanm",
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
- ):
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- elif input_layer == "pe_online":
- self.embed = StreamSinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
- elif selfattention_layer_type == "sanm":
- encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks - 1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- shfit_fsmn = (kernel_size - 1) // 2
- self.overlap_chunk_cls = overlap_chunk(
- chunk_size=chunk_size,
- stride=stride,
- pad_left=pad_left,
- shfit_fsmn=shfit_fsmn,
- encoder_att_look_back_factor=encoder_att_look_back_factor,
- decoder_att_look_back_factor=decoder_att_look_back_factor,
- )
- self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
- self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad *= self.output_size() ** 0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- mask_shfit_chunk, mask_att_chunk_encoder = None, None
- if self.overlap_chunk_cls is not None:
- ilens = masks.squeeze(1).sum(1)
- chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
- xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
-
- encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
- def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
- if len(cache) == 0:
- return feats
- cache["feats"] = to_device(cache["feats"], device=feats.device)
- overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
- return overlap_feats
-
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- ):
- xs_pad *= self.output_size() ** 0.5
- if self.embed is None:
- xs_pad = xs_pad
- else:
- xs_pad = self.embed(xs_pad, cache)
- if cache["tail_chunk"]:
- xs_pad = to_device(cache["feats"], device=xs_pad.device)
- else:
- xs_pad = self._add_overlap_chunk(xs_pad, cache)
- if cache["opt"] is None:
- cache_layer_num = len(self.encoders0) + len(self.encoders)
- new_cache = [None] * cache_layer_num
- else:
- new_cache = cache["opt"]
-
- for layer_idx, encoder_layer in enumerate(self.encoders0):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
-
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
- if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
- cache["opt"] = new_cache
-
- return xs_pad, ilens, None
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
-
-class SANMVadEncoder(AbsEncoder):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- selfattention_layer_type: str = "sanm",
- ):
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- vad_indexes: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
- no_future_masks = masks & sub_masks
- xs_pad *= self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling " +
- f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- mask_tup0 = [masks, no_future_masks]
- encoder_outs = self.encoders0(xs_pad, mask_tup0)
- xs_pad, _ = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
-
-
- for layer_idx, encoder_layer in enumerate(self.encoders):
- if layer_idx + 1 == len(self.encoders):
- # This is last layer.
- coner_mask = torch.ones(masks.size(0),
- masks.size(-1),
- masks.size(-1),
- device=xs_pad.device,
- dtype=torch.bool)
- for word_index, length in enumerate(ilens):
- coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
- vad_indexes[word_index],
- device=xs_pad.device)
- layer_mask = masks & coner_mask
- else:
- layer_mask = no_future_masks
- mask_tup1 = [masks, layer_mask]
- encoder_outs = encoder_layer(xs_pad, mask_tup1)
- xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
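The forward_chunk paths above all consume a per-utterance cache dict whose initialization lives outside this hunk. A minimal, illustrative sketch of the layout they assume (the key names come from the code above; every concrete value is a placeholder, not a recommended setting):

    import torch

    feat_dim = 560            # illustrative input feature dimension
    chunk_conf = [5, 10, 5]   # illustrative 3-element chunk configuration

    cache = {
        # rolling overlap buffer, trimmed to chunk_size[0] + chunk_size[2] frames in _add_overlap_chunk
        "feats": torch.zeros(1, chunk_conf[0] + chunk_conf[2], feat_dim),
        "chunk_size": chunk_conf,
        "encoder_chunk_look_back": 4,   # >0 or -1 makes SANMEncoderChunkOpt keep its per-layer caches
        "tail_chunk": False,            # set True for the final, padded chunk of the utterance
        "opt": None,                    # per-layer attention caches, filled by forward_chunk
    }

    # per incoming chunk of features `feats` with shape (1, T_chunk, feat_dim):
    #     enc_out, _, _ = encoder.forward_chunk(feats, ilens, cache=cache)
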
diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/scama/sanm_decoder.py
similarity index 70%
copy from funasr/models/paraformer/contextual_decoder.py
copy to funasr/models/scama/sanm_decoder.py
index 626cdef..53423d0 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/scama/sanm_decoder.py
@@ -6,17 +6,38 @@
import numpy as np
from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from funasr.models.transformer.decoder import BaseTransformerDecoder
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (encoder-decoder attention) module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+            If True, an additional linear layer is applied,
+            i.e. x -> x + linear(concat(x, att(x))).
+            If False, no additional linear layer is applied, i.e. x -> x + att(x).
-class ContextualDecoderLayer(nn.Module):
+ """
+
def __init__(
self,
size,
@@ -28,7 +49,7 @@
concat_after=False,
):
"""Construct an DecoderLayer object."""
- super(ContextualDecoderLayer, self).__init__()
+ super(DecoderLayerSANM, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
@@ -45,85 +66,161 @@
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
- def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+            List[torch.Tensor]: Cache, passed through unchanged.
+
+ """
# tgt = self.dropout(tgt)
- if isinstance(tgt, Tuple):
- tgt, _ = tgt
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt = self.feed_forward(tgt)
x = tgt
- if self.normalize_before:
- tgt = self.norm2(tgt)
- if self.training:
- cache = None
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
- x = residual + self.dropout(x)
- x_self_attn = x
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, _ = self.self_attn(tgt, tgt_mask)
+ x = residual + self.dropout(x)
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x = self.src_attn(x, memory, memory_mask)
- x_src_attn = x
-
- x = residual + self.dropout(x)
- return x, tgt_mask, x_self_attn, x_src_attn
-
-
-class ContextualBiasDecoder(nn.Module):
- def __init__(
- self,
- size,
- src_attn,
- dropout_rate,
- normalize_before=True,
- ):
- """Construct an DecoderLayer object."""
- super(ContextualBiasDecoder, self).__init__()
- self.size = size
- self.src_attn = src_attn
- if src_attn is not None:
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
-
- def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
- x = tgt
if self.src_attn is not None:
+ residual = x
if self.normalize_before:
x = self.norm3(x)
- x = self.dropout(self.src_attn(x, memory, memory_mask))
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
return x, tgt_mask, memory, memory_mask, cache
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ """Compute decoded features.
-class ContextualParaformerDecoder(ParaformerSANMDecoder):
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+            List[torch.Tensor]: Updated self-attention cache.
+
+ """
+ # tgt = self.dropout(tgt)
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ if self.training:
+ cache = None
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn:
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+ x = residual + self.dropout(x)
+
+ if self.src_attn is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+ x = residual + x
+
+ return x, memory, fsmn_cache, opt_cache
+
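For orientation, a hedged sketch of how a single DecoderLayerSANM threads its two caches through forward_chunk during streaming decoding. The constructor arguments mirror the repeat(...) call further down; the chunk iterator and the chunk_size and look_back values are illustrative placeholders, not part of this patch:

    layer = DecoderLayerSANM(
        attention_dim,
        MultiHeadedAttentionSANMDecoder(attention_dim, dropout_rate, kernel_size, sanm_shfit=sanm_shfit),
        MultiHeadedAttentionCrossAtt(attention_heads, attention_dim, dropout_rate,
                                     encoder_output_size=encoder_output_size),
        PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
        dropout_rate,
    )

    fsmn_cache, opt_cache = None, None
    for tgt_chunk in tgt_chunks:   # hypothetical stream of embedded target chunks
        x, memory, fsmn_cache, opt_cache = layer.forward_chunk(
            tgt_chunk, memory, fsmn_cache, opt_cache, chunk_size, look_back
        )
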
+@register_class("decoder_classes", "FsmnDecoderSCAMAOpt")
+class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
+
"""
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = 0,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = None,
+ concat_embeds: bool = False,
+ attention_dim: int = None,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ embed_tensor_name_prefix_tf: str = None,
):
super().__init__(
vocab_size=vocab_size,
@@ -135,14 +232,12 @@
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
+ if attention_dim is None:
+ attention_dim = encoder_output_size
- attention_dim = encoder_output_size
- if input_layer == 'none':
- self.embed = None
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
- # pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
@@ -168,14 +263,14 @@
if sanm_shfit is None:
sanm_shfit = (kernel_size - 1) // 2
self.decoders = repeat(
- att_layer_num - 1,
+ att_layer_num,
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
+ attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -183,29 +278,6 @@
concat_after,
),
)
- self.dropout = nn.Dropout(dropout_rate)
- self.bias_decoder = ContextualBiasDecoder(
- size=attention_dim,
- src_attn=MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- dropout_rate=dropout_rate,
- normalize_before=True,
- )
- self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
- self.last_decoder = ContextualDecoderLayer(
- attention_dim,
- MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
- ),
- MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
- PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
- dropout_rate,
- normalize_before,
- concat_after,
- )
if num_blocks - att_layer_num <= 0:
self.decoders2 = None
else:
@@ -214,7 +286,7 @@
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+ attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -236,16 +308,35 @@
concat_after,
),
)
+ if concat_embeds:
+ self.embed_concat_ffn = repeat(
+ 1,
+ lambda lnum: DecoderLayerSANM(
+ attention_dim + encoder_output_size,
+ None,
+ None,
+ PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+ adim=attention_dim),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ else:
+ self.embed_concat_ffn = None
+ self.concat_embeds = concat_embeds
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- contextual_info: torch.Tensor,
- clas_scale: float = 1.0,
- return_hidden: bool = False,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -269,46 +360,122 @@
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ if chunk_mask is not None:
+ memory_mask = memory_mask * chunk_mask
+ if tgt_mask.size(1) != memory_mask.size(1):
+ memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
- x = tgt
+ x = self.embed(tgt)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
x, tgt_mask, memory, memory_mask, _ = self.decoders(
x, tgt_mask, memory, memory_mask
)
- _, _, x_self_attn, x_src_attn = self.last_decoder(
- x, tgt_mask, memory, memory_mask
- )
-
- # contextual paraformer related
- contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
- contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
- cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
-
- if self.bias_output is not None:
- x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
- x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
- x = x_self_attn + self.dropout(x)
-
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.decoders2(
x, tgt_mask, memory, memory_mask
)
-
x, tgt_mask, memory, memory_mask, _ = self.decoders3(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
- olens = tgt_mask.sum(1)
- if self.output_layer is not None and return_hidden is False:
+ if self.output_layer is not None:
x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
return x, olens
- def gen_tf2torch_map_dict(self):
+    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None):
+ """Score."""
+ ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=state
+ )
+ return logp.squeeze(0), state
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, maxlen_out, token)
+ """
+
+ x = tgt[:, -1:]
+ tgt_mask = None
+ x = self.embed(x)
+
+ if pre_acoustic_embeds is not None and self.concat_embeds:
+ x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+ x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+ if cache is None:
+ cache_layer_num = len(self.decoders)
+ if self.decoders2 is not None:
+ cache_layer_num += len(self.decoders2)
+ cache = [None] * cache_layer_num
+ new_cache = []
+ # for c, decoder in zip(cache, self.decoders):
+ for i in range(self.att_layer_num):
+ decoder = self.decoders[i]
+ c = cache[i]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ if self.num_blocks - self.att_layer_num >= 1:
+ for i in range(self.num_blocks - self.att_layer_num):
+ j = i + self.att_layer_num
+ decoder = self.decoders2[i]
+ c = cache[j]
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+ x, tgt_mask, memory, memory_mask, cache=c
+ )
+ new_cache.append(c_ret)
+
+ for decoder in self.decoders3:
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+ x, tgt_mask, memory, None, cache=None
+ )
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = self.output_layer(y)
+ y = torch.log_softmax(y, dim=-1)
+
+ return y, new_cache
+
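A hedged sketch of driving forward_one_step via score() for step-by-step greedy decoding; sos, eos, max_len and enc_out are illustrative placeholders, and the beam-search integration is not shown:

    ys = enc_out.new_tensor([sos], dtype=torch.long)   # running hypothesis, starts with <sos>
    state = None                                       # per-layer cache list, built lazily by forward_one_step
    for _ in range(max_len):
        logp, state = decoder.score(ys, state, enc_out)    # logp: (vocab_size,) log-probabilities
        next_token = int(logp.argmax())
        ys = torch.cat([ys, ys.new_tensor([next_token])])
        if next_token == eos:
            break
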
+ def gen_tf2torch_map_dict(self):
+
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
map_dict_local = {
-
+
## decoder
# ffn
"{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
@@ -346,7 +513,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# fsmn
"{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
{"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
@@ -443,7 +610,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# embed_concat_ffn
"{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
{"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -480,7 +647,7 @@
"squeeze": 0,
"transpose": (1, 0),
}, # (256,1024),(1,1024,256)
-
+
# out norm
"{}.after_norm.weight".format(tensor_name_prefix_torch):
{"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -492,17 +659,18 @@
"squeeze": None,
"transpose": None,
}, # (256,),(256,)
-
+
# in embed
"{}.embed.0.weight".format(tensor_name_prefix_torch):
- {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+ {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
"squeeze": None,
"transpose": None,
}, # (4235,256),(4235,256)
-
+
# out layer
"{}.output_layer.weight".format(tensor_name_prefix_torch):
- {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+ {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+ "{}/w_embs".format(embed_tensor_name_prefix_tf)],
"squeeze": [None, None],
"transpose": [(1, 0), None],
}, # (4235,256),(256,4235)
@@ -512,56 +680,7 @@
"squeeze": [None, None],
"transpose": [None, None],
}, # (4235,),(4235,)
-
- ## clas decoder
- # src att
- "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # dnn
- "{}.bias_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (1024,256),(1,256,1024)
-
+
}
return map_dict_local
@@ -569,6 +688,7 @@
var_dict_tf,
var_dict_torch,
):
+
map_dict = self.gen_tf2torch_map_dict()
var_dict_torch_update = dict()
decoder_layeridx_sets = set()
@@ -598,37 +718,13 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
- elif names[1] == "last_decoder":
- layeridx = 15
- name_q = name.replace("last_decoder", "decoders.layeridx")
- layeridx_bias = 0
- layeridx += layeridx_bias
- decoder_layeridx_sets.add(layeridx)
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
-
+
elif names[1] == "decoders2":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
name_q = name_q.replace("decoders2", "decoders")
layeridx_bias = len(decoder_layeridx_sets)
-
+
layeridx += layeridx_bias
if "decoders." in name:
decoder_layeridx_sets.add(layeridx)
@@ -649,11 +745,11 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "decoders3":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
@@ -675,29 +771,8 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
- elif names[1] == "bias_decoder":
- name_q = name
-
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
-
- elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+
+ elif names[1] == "embed" or names[1] == "output_layer":
name_tf = map_dict[name]["name"]
if isinstance(name_tf, list):
idx_list = 0
@@ -720,7 +795,7 @@
name_tf[idx_list],
var_dict_tf[name_tf[
idx_list]].shape))
-
+
else:
data_tf = var_dict_tf[name_tf]
if map_dict[name]["squeeze"] is not None:
@@ -736,7 +811,7 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "after_norm":
name_tf = map_dict[name]["name"]
data_tf = var_dict_tf[name_tf]
@@ -745,11 +820,11 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
-
+
elif names[1] == "embed_concat_ffn":
layeridx = int(names[2])
name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+
layeridx_bias = 0
layeridx += layeridx_bias
if "decoders." in name:
@@ -771,5 +846,6 @@
logging.info(
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
var_dict_tf[name_tf].shape))
-
+
return var_dict_torch_update
+
diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/sanm_encoder.py
new file mode 100644
index 0000000..c89bfb3
--- /dev/null
+++ b/funasr/models/scama/sanm_encoder.py
@@ -0,0 +1,613 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.scama.chunk_utilis import overlap_chunk
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class, registry_tables
+
+class EncoderLayerSANM(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayerSANM, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(in_size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.in_size = in_size
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ """Compute encoded features.
+
+ Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask applied inside the FSMN memory block.
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask applied to the attention scores.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+            The cache and the two chunk masks are also returned unchanged.
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ else:
+ x = stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
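The stochastic-depth comment in forward above states the scaling rule only briefly; a minimal sketch of that rule in isolation (stochastic_residual and f are illustrative names):

import torch

def stochastic_residual(x, f, p: float, training: bool):
    # During training the whole layer is skipped with probability p; when it survives,
    # its contribution is rescaled by 1 / (1 - p) so the expected output matches inference.
    if training and p > 0:
        if torch.rand(1).item() < p:
            return x
        return x + (1.0 / (1 - p)) * f(x)
    return x + f(x)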
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute encoded features.
+
+ Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            cache: Attention cache carried over from the previous chunk (None on the first call).
+            chunk_size: Chunk-size configuration used by the streaming attention.
+            look_back (int): Number of history chunks attended to (-1 for unlimited).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            Updated attention cache for the next chunk.
+
+ """
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.in_size == self.size:
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+ x = residual + attn
+ else:
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.feed_forward(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, cache
+
+
+@register_class("encoder_classes", "SANMEncoderChunkOpt")
+class SANMEncoderChunkOpt(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+ https://arxiv.org/abs/2006.01713
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shfit: int = 0,
+ selfattention_layer_type: str = "sanm",
+ chunk_size: Union[int, Sequence[int]] = (16,),
+ stride: Union[int, Sequence[int]] = (10,),
+ pad_left: Union[int, Sequence[int]] = (0,),
+ encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ elif selfattention_layer_type == "sanm":
+ encoder_selfattn_layer = MultiHeadedAttentionSANM
+ encoder_selfattn_layer_args0 = (
+ attention_heads,
+ input_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+ self.encoders0 = repeat(
+ 1,
+ lambda lnum: EncoderLayerSANM(
+ input_size,
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.encoders = repeat(
+ num_blocks - 1,
+ lambda lnum: EncoderLayerSANM(
+ output_size,
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ shfit_fsmn = (kernel_size - 1) // 2
+ self.overlap_chunk_cls = overlap_chunk(
+ chunk_size=chunk_size,
+ stride=stride,
+ pad_left=pad_left,
+ shfit_fsmn=shfit_fsmn,
+ encoder_att_look_back_factor=encoder_att_look_back_factor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ )
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
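A hedged construction sketch for SANMEncoderChunkOpt as configured by the __init__ above; the values mirror the signature's defaults and the chunk/stride tuples are illustrative, not taken from a shipped config:

encoder = SANMEncoderChunkOpt(
    input_size=560,                      # e.g. an LFR fbank dimension (assumption)
    output_size=256,
    attention_heads=4,
    num_blocks=6,
    selfattention_layer_type="sanm",
    chunk_size=(16,),
    stride=(10,),
    pad_left=(0,),
    encoder_att_look_back_factor=(1,),
    decoder_att_look_back_factor=(1,),
)
# encoder(xs_pad, ilens) then returns (encoder_out, olens, None), see forward below.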
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ind: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
+ ):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ mask_shfit_chunk, mask_att_chunk_encoder = None, None
+ if self.overlap_chunk_cls is not None:
+ ilens = masks.squeeze(1).sum(1)
+ chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
+ xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
+ dtype=xs_pad.dtype)
+ mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
+ xs_pad.size(0),
+ dtype=xs_pad.dtype)
+
+ encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
+
+    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ else:
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
+ if cache["opt"] is None:
+ cache_layer_num = len(self.encoders0) + len(self.encoders)
+ new_cache = [None] * cache_layer_num
+ else:
+ new_cache = cache["opt"]
+
+ for layer_idx, encoder_layer in enumerate(self.encoders0):
+ encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
+
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+ if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
+ cache["opt"] = new_cache
+
+ return xs_pad, ilens, None
+
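forward_chunk above expects a caller-maintained cache dict; the key names are taken from the code, but the initialization shown below is an assumption (a hypothetical helper, not part of the patch):

import torch

def init_encoder_cache(chunk_size, feat_dim, encoder_chunk_look_back, batch=1):
    return {
        # _add_overlap_chunk keeps the last chunk_size[0] + chunk_size[2] frames here
        "feats": torch.zeros(batch, chunk_size[0] + chunk_size[2], feat_dim),
        "chunk_size": chunk_size,
        "opt": None,                     # per-layer attention caches, created on the first call
        "tail_chunk": False,             # set True for the final, padded chunk
        "encoder_chunk_look_back": encoder_chunk_look_back,
    }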
+ def gen_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ map_dict_local = {
+ ## encoder
+ # cicd
+ "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (768,256),(1,256,768)
+ "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (768,),(768,)
+ "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 2, 0),
+ }, # (256,1,31),(1,31,256,1)
+ "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # ffn
+ "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+ "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # out norm
+ "{}.after_norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.after_norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+
+ }
+
+ return map_dict_local
+
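The mapping above is keyed on templated names; convert_tf2torch below rewrites a concrete parameter name into such a key and substitutes the (bias-shifted) layer index back into the TF name. A worked example with illustrative values:

torch_name = "encoder.encoders.2.norm1.weight"
layeridx = 2
key = torch_name.replace(".{}.".format(layeridx), ".layeridx.")
# key == "encoder.encoders.layeridx.norm1.weight", a key of gen_tf2torch_map_dict()
layeridx += 1  # the "encoders" stack starts after the single "encoders0" layer (layeridx_bias = 1)
tf_name = "seq2seq/encoder/layer_layeridx/multi_head/LayerNorm/gamma".replace("layeridx", str(layeridx))
# tf_name == "seq2seq/encoder/layer_3/multi_head/LayerNorm/gamma"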
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+
+ map_dict = self.gen_tf2torch_map_dict()
+
+ var_dict_torch_update = dict()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == self.tf2torch_tensor_name_prefix_torch:
+ if names[1] == "encoders0":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ name_q = name_q.replace("encoders0", "encoders")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[1] == "encoders":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ layeridx_bias = 1
+ layeridx += layeridx_bias
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "after_norm":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
+
diff --git a/funasr/models/sond/attention.py b/funasr/models/sond/attention.py
new file mode 100644
index 0000000..290ab03
--- /dev/null
+++ b/funasr/models/sond/attention.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct a MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
+
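A minimal usage sketch for the MultiHeadedAttention module above; the shapes follow its docstrings and the concrete sizes are illustrative:

import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                                          # (#batch, time, n_feat)
lengths = torch.tensor([50, 30])
mask = (torch.arange(50)[None, :] < lengths[:, None]).unsqueeze(1)   # (#batch, 1, time2) padding mask
out = mha(x, x, x, mask)                                             # self-attention, (2, 50, 256)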
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ Paper: https://arxiv.org/abs/1901.02860
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct a RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ self.zero_triu = zero_triu
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x):
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+
+ if self.zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, 2*time1-1, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, 2*time1-1)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
+
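A usage sketch for RelPositionMultiHeadedAttention; pos_emb carries the 2*time1-1 relative positions described in the docstring, and the sizes are illustrative:

import torch

rel_mha = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 20, 256)                 # (#batch, time1, size)
pos_emb = torch.randn(1, 2 * 20 - 1, 256)   # (#batch, 2*time1-1, size), broadcast over the batch
mask = torch.ones(2, 1, 20, dtype=torch.bool)
out = rel_mha(x, x, x, pos_emb, mask)       # (2, 20, 256)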
+
+
+
+
+
+class MultiHeadSelfAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct a MultiHeadSelfAttention object."""
+ super(MultiHeadSelfAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+ Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Value tensor before the head split (#batch, time, n_feat).
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+ Args:
+            x (torch.Tensor): Input tensor (#batch, time, size) used as query, key and value.
+            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
+                (#batch, time, time).
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied into the mask before it is applied.
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs
+
+
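A usage sketch for the fused-projection MultiHeadSelfAttention above; in_feat may differ from n_feat because query, key and value come from the single linear_q_k_v projection (sizes are illustrative):

import torch

attn = MultiHeadSelfAttention(n_head=4, in_feat=560, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 30, 560)    # (#batch, time, in_feat)
mask = torch.ones(2, 1, 30)    # (#batch, 1, time)
out = attn(x, mask)            # (2, 30, 256)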
diff --git a/funasr/models/sond/e2e_diar_sond.py b/funasr/models/sond/e2e_diar_sond.py
index ff70502..21d1d59 100644
--- a/funasr/models/sond/e2e_diar_sond.py
+++ b/funasr/models/sond/e2e_diar_sond.py
@@ -18,7 +18,7 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/sond/encoder/conv_encoder.py b/funasr/models/sond/encoder/conv_encoder.py
index 4c345cb..3933c01 100644
--- a/funasr/models/sond/encoder/conv_encoder.py
+++ b/funasr/models/sond/encoder/conv_encoder.py
@@ -12,7 +12,7 @@
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
class EncoderLayer(nn.Module):
diff --git a/funasr/models/sond/encoder/fsmn_encoder.py b/funasr/models/sond/encoder/fsmn_encoder.py
index e23f3f1..129a748 100644
--- a/funasr/models/sond/encoder/fsmn_encoder.py
+++ b/funasr/models/sond/encoder/fsmn_encoder.py
@@ -12,7 +12,7 @@
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.multi_layer_conv import FsmnFeedForward
diff --git a/funasr/models/sond/encoder/self_attention_encoder.py b/funasr/models/sond/encoder/self_attention_encoder.py
index c620295..ea974c6 100644
--- a/funasr/models/sond/encoder/self_attention_encoder.py
+++ b/funasr/models/sond/encoder/self_attention_encoder.py
@@ -9,7 +9,7 @@
from funasr.models.scama.chunk_utilis import overlap_chunk
import numpy as np
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
+from funasr.models.sond.attention import MultiHeadSelfAttention
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
@@ -17,13 +17,13 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
diff --git a/funasr/models/specaug/specaug.py b/funasr/models/specaug/specaug.py
index e5da8e2..17f2657 100644
--- a/funasr/models/specaug/specaug.py
+++ b/funasr/models/specaug/specaug.py
@@ -7,9 +7,11 @@
from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr.models.specaug.mask_along_axis import MaskAlongAxisLFR
from funasr.models.specaug.time_warp import TimeWarp
+from funasr.utils.register import register_class
import torch.nn as nn
+@register_class("specaug_classes", "SpecAug")
class SpecAug(nn.Module):
"""Implementation of SpecAug.
@@ -99,7 +101,8 @@
x, x_lengths = self.time_mask(x, x_lengths)
return x, x_lengths
-class SpecAugLFR(AbsSpecAug):
+@register_class("specaug_classes", "SpecAugLFR")
+class SpecAugLFR(nn.Module):
"""Implementation of SpecAug.
     lfr_rate: low frame rate
"""
diff --git a/funasr/models/tp_aligner/e2e_tp.py b/funasr/models/tp_aligner/e2e_tp.py
index e6c4028..c675b0e 100644
--- a/funasr/models/tp_aligner/e2e_tp.py
+++ b/funasr/models/tp_aligner/e2e_tp.py
@@ -11,13 +11,13 @@
import numpy as np
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.predictor.cif import mae_loss
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.frontends.abs_frontend import AbsFrontend
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
-from funasr.models.predictor.cif import CifPredictorV3
+from funasr.models.paraformer.cif_predictor import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index ec340f6..21de4ba 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -16,7 +16,6 @@
import random
import numpy as np
import time
-# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
@@ -24,17 +23,17 @@
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
+# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
+from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
+# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis
from funasr.models.model_class_factory import *
@@ -46,7 +45,7 @@
@contextmanager
def autocast(enabled=True):
yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
@@ -98,19 +97,19 @@
super().__init__()
if frontend is not None:
- frontend_class = frontend_choices.get_class(frontend)
+ frontend_class = frontend_classes.get_class(frontend)
frontend = frontend_class(**frontend_conf)
if specaug is not None:
- specaug_class = specaug_choices.get_class(specaug)
+ specaug_class = specaug_classes.get_class(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
- normalize_class = normalize_choices.get_class(normalize)
+ normalize_class = normalize_classes.get_class(normalize)
normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_choices.get_class(encoder)
+ encoder_class = encoder_classes.get_class(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
- decoder_class = decoder_choices.get_class(decoder)
+ decoder_class = decoder_classes.get_class(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -118,7 +117,7 @@
)
decoder_output_size = decoder.output_size
- joint_network_class = joint_network_choices.get_class(decoder)
+ joint_network_class = joint_network_classes.get_class(decoder)
joint_network = joint_network_class(
vocab_size,
encoder_output_size,
@@ -521,7 +520,7 @@
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
diff --git a/funasr/models/transducer/rnn_decoder.py b/funasr/models/transducer/rnn_decoder.py
index 1743f99..204f0b1 100644
--- a/funasr/models/transducer/rnn_decoder.py
+++ b/funasr/models/transducer/rnn_decoder.py
@@ -2,13 +2,12 @@
import numpy as np
import torch
+import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import initial_att
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.utils.get_default_kwargs import get_default_kwargs
def build_attention_list(
@@ -80,7 +79,7 @@
return att_list
-class RNNDecoder(AbsDecoder):
+class RNNDecoder(nn.Module):
def __init__(
self,
vocab_size: int,
@@ -93,7 +92,7 @@
context_residual: bool = False,
replace_sos: bool = False,
num_encs: int = 1,
- att_conf: dict = get_default_kwargs(build_attention_list),
+ att_conf: dict = None,
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
if rnn_type not in {"lstm", "gru"}:
diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py
index 04607c6..32e1e47 100644
--- a/funasr/models/transformer/attention.py
+++ b/funasr/models/transformer/attention.py
@@ -15,7 +15,7 @@
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import funasr.modules.lora.layers as lora
+import funasr.models.lora.layers as lora
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -312,780 +312,8 @@
return self.forward_attention(v, scores, mask)
-class MultiHeadedAttentionSANM(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
- """Construct an MultiHeadedAttention object."""
- super(MultiHeadedAttentionSANM, self).__init__()
- assert n_feat % n_head == 0
- # We assume d_v always equals d_k
- self.d_k = n_feat // n_head
- self.h = n_head
- # self.linear_q = nn.Linear(n_feat, n_feat)
- # self.linear_k = nn.Linear(n_feat, n_feat)
- # self.linear_v = nn.Linear(n_feat, n_feat)
- if lora_list is not None:
- if "o" in lora_list:
- self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
- else:
- self.linear_out = nn.Linear(n_feat, n_feat)
- lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
- if lora_qkv_list == [False, False, False]:
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
- else:
- self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
- else:
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
- self.attn = None
- self.dropout = nn.Dropout(p=dropout_rate)
-
- self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
- # padding
- left_padding = (kernel_size - 1) // 2
- if sanm_shfit > 0:
- left_padding = left_padding + sanm_shfit
- right_padding = kernel_size - 1 - left_padding
- self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-
- def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
- b, t, d = inputs.size()
- if mask is not None:
- mask = torch.reshape(mask, (b, -1, 1))
- if mask_shfit_chunk is not None:
- mask = mask * mask_shfit_chunk
- inputs = inputs * mask
-
- x = inputs.transpose(1, 2)
- x = self.pad_fn(x)
- x = self.fsmn_block(x)
- x = x.transpose(1, 2)
- x += inputs
- x = self.dropout(x)
- if mask is not None:
- x = x * mask
- return x
-
- def forward_qkv(self, x):
- """Transform query, key and value.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
-
- Returns:
- torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
- torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
- torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
- """
- b, t, d = x.size()
- q_k_v = self.linear_q_k_v(x)
- q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
- q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
- k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
- v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
-
- return q_h, k_h, v_h, v
-
- def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
- """Compute attention context vector.
-
- Args:
- value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Transformed value (#batch, time1, d_model)
- weighted by the attention score (#batch, time1, time2).
-
- """
- n_batch = value.size(0)
- if mask is not None:
- if mask_att_chunk_encoder is not None:
- mask = mask * mask_att_chunk_encoder
-
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
-
- min_value = float(
- numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
- )
- scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
- mask, 0.0
- ) # (batch, head, time1, time2)
- else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
-
- p_attn = self.dropout(self.attn)
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
- x = (
- x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
- ) # (batch, time1, d_model)
-
- return self.linear_out(x) # (batch, time1, d_model)
-
- def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q_h, k_h, v_h, v = self.forward_qkv(x)
- fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
- return att_outs + fsmn_memory
-
- def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q_h, k_h, v_h, v = self.forward_qkv(x)
- if chunk_size is not None and look_back > 0 or look_back == -1:
- if cache is not None:
- k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
- v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
- k_h = torch.cat((cache["k"], k_h), dim=2)
- v_h = torch.cat((cache["v"], v_h), dim=2)
-
- cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
- cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
- if look_back != -1:
- cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
- cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
- else:
- cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
- "v": v_h[:, :, :-(chunk_size[2]), :]}
- cache = cache_tmp
- fsmn_memory = self.forward_fsmn(v, None)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, None)
- return att_outs + fsmn_memory, cache
-
-
-class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- q_h, k_h, v_h, v = self.forward_qkv(x)
- fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
- return att_outs + fsmn_memory
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
- """Construct an MultiHeadedAttention object."""
- super(MultiHeadedAttentionSANMDecoder, self).__init__()
-
- self.dropout = nn.Dropout(p=dropout_rate)
-
- self.fsmn_block = nn.Conv1d(n_feat, n_feat,
- kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
- # padding
- # padding
- left_padding = (kernel_size - 1) // 2
- if sanm_shfit > 0:
- left_padding = left_padding + sanm_shfit
- right_padding = kernel_size - 1 - left_padding
- self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
- self.kernel_size = kernel_size
-
- def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
- '''
- :param x: (#batch, time1, size).
- :param mask: Mask tensor (#batch, 1, time)
- :return:
- '''
- # print("in fsmn, inputs", inputs.size())
- b, t, d = inputs.size()
- # logging.info(
- # "mask: {}".format(mask.size()))
- if mask is not None:
- mask = torch.reshape(mask, (b ,-1, 1))
- # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
- if mask_shfit_chunk is not None:
- # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
- mask = mask * mask_shfit_chunk
- # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
- # print("in fsmn, mask", mask.size())
- # print("in fsmn, inputs", inputs.size())
- inputs = inputs * mask
-
- x = inputs.transpose(1, 2)
- b, d, t = x.size()
- if cache is None:
- # print("in fsmn, cache is None, x", x.size())
-
- x = self.pad_fn(x)
- if not self.training:
- cache = x
- else:
- # print("in fsmn, cache is not None, x", x.size())
- # x = torch.cat((x, cache), dim=2)[:, :, :-1]
- # if t < self.kernel_size:
- # x = self.pad_fn(x)
- x = torch.cat((cache[:, :, 1:], x), dim=2)
- x = x[:, :, -(self.kernel_size+t-1):]
- # print("in fsmn, cache is not None, x_cat", x.size())
- cache = x
- x = self.fsmn_block(x)
- x = x.transpose(1, 2)
- # print("in fsmn, fsmn_out", x.size())
- if x.size(1) != inputs.size(1):
- inputs = inputs[:, -1, :]
-
- x = x + inputs
- x = self.dropout(x)
- if mask is not None:
- x = x * mask
- return x, cache
-
-class MultiHeadedAttentionCrossAtt(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
- """Construct an MultiHeadedAttention object."""
- super(MultiHeadedAttentionCrossAtt, self).__init__()
- assert n_feat % n_head == 0
- # We assume d_v always equals d_k
- self.d_k = n_feat // n_head
- self.h = n_head
- if lora_list is not None:
- if "q" in lora_list:
- self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
- else:
- self.linear_q = nn.Linear(n_feat, n_feat)
- lora_kv_list = ["k" in lora_list, "v" in lora_list]
- if lora_kv_list == [False, False]:
- self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
- else:
- self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
- r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
- if "o" in lora_list:
- self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
- else:
- self.linear_out = nn.Linear(n_feat, n_feat)
- else:
- self.linear_q = nn.Linear(n_feat, n_feat)
- self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.attn = None
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward_qkv(self, x, memory):
- """Transform query, key and value.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
-
- Returns:
- torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
- torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
- torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
- """
-
- # print("in forward_qkv, x", x.size())
- b = x.size(0)
- q = self.linear_q(x)
- q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
-
- k_v = self.linear_k_v(memory)
- k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
- k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
- v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
-
-
- return q_h, k_h, v_h
-
- def forward_attention(self, value, scores, mask):
- """Compute attention context vector.
-
- Args:
- value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Transformed value (#batch, time1, d_model)
- weighted by the attention score (#batch, time1, time2).
-
- """
- n_batch = value.size(0)
- if mask is not None:
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(
- numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
- )
- # logging.info(
- # "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
- scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
- mask, 0.0
- ) # (batch, head, time1, time2)
- else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
-
- p_attn = self.dropout(self.attn)
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
- x = (
- x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
- ) # (batch, time1, d_model)
-
- return self.linear_out(x) # (batch, time1, d_model)
-
- def forward(self, x, memory, memory_mask):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q_h, k_h, v_h = self.forward_qkv(x, memory)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- return self.forward_attention(v_h, scores, memory_mask)
-
- def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q_h, k_h, v_h = self.forward_qkv(x, memory)
- if chunk_size is not None and look_back > 0:
- if cache is not None:
- k_h = torch.cat((cache["k"], k_h), dim=2)
- v_h = torch.cat((cache["v"], v_h), dim=2)
- cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
- cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
- else:
- cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
- "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
- cache = cache_tmp
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- return self.forward_attention(v_h, scores, None), cache
-
-
-class MultiHeadSelfAttention(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_head, in_feat, n_feat, dropout_rate):
- """Construct an MultiHeadedAttention object."""
- super(MultiHeadSelfAttention, self).__init__()
- assert n_feat % n_head == 0
- # We assume d_v always equals d_k
- self.d_k = n_feat // n_head
- self.h = n_head
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
- self.attn = None
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward_qkv(self, x):
- """Transform query, key and value.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
-
- Returns:
- torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
- torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
- torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
- """
- b, t, d = x.size()
- q_k_v = self.linear_q_k_v(x)
- q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
- q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
- k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
- v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
-
- return q_h, k_h, v_h, v
-
- def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
- """Compute attention context vector.
-
- Args:
- value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Transformed value (#batch, time1, d_model)
- weighted by the attention score (#batch, time1, time2).
-
- """
- n_batch = value.size(0)
- if mask is not None:
- if mask_att_chunk_encoder is not None:
- mask = mask * mask_att_chunk_encoder
-
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
-
- min_value = float(
- numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
- )
- scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
- mask, 0.0
- ) # (batch, head, time1, time2)
- else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
-
- p_attn = self.dropout(self.attn)
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
- x = (
- x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
- ) # (batch, time1, d_model)
-
- return self.linear_out(x) # (batch, time1, d_model)
-
- def forward(self, x, mask, mask_att_chunk_encoder=None):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q_h, k_h, v_h, v = self.forward_qkv(x)
- q_h = q_h * self.d_k ** (-0.5)
- scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
- return att_outs
-
-class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
- """RelPositionMultiHeadedAttention definition.
- Args:
- num_heads: Number of attention heads.
- embed_size: Embedding size.
- dropout_rate: Dropout rate.
- """
-
- def __init__(
- self,
- num_heads: int,
- embed_size: int,
- dropout_rate: float = 0.0,
- simplified_attention_score: bool = False,
- ) -> None:
- """Construct an MultiHeadedAttention object."""
- super().__init__()
-
- self.d_k = embed_size // num_heads
- self.num_heads = num_heads
-
- assert self.d_k * num_heads == embed_size, (
- "embed_size (%d) must be divisible by num_heads (%d)",
- (embed_size, num_heads),
- )
-
- self.linear_q = torch.nn.Linear(embed_size, embed_size)
- self.linear_k = torch.nn.Linear(embed_size, embed_size)
- self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
- self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
- if simplified_attention_score:
- self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
- self.compute_att_score = self.compute_simplified_attention_score
- else:
- self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
- self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
- self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
- torch.nn.init.xavier_uniform_(self.pos_bias_u)
- torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
- self.compute_att_score = self.compute_attention_score
-
- self.dropout = torch.nn.Dropout(p=dropout_rate)
- self.attn = None
-
- def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
- """Compute relative positional encoding.
- Args:
- x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
- left_context: Number of frames in left context.
- Returns:
- x: Output sequence. (B, H, T_1, T_2)
- """
- batch_size, n_heads, time1, n = x.shape
- time2 = time1 + left_context
-
- batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
- return x.as_strided(
- (batch_size, n_heads, time1, time2),
- (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
- storage_offset=(n_stride * (time1 - 1)),
- )
-
- def compute_simplified_attention_score(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- pos_enc: torch.Tensor,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Simplified attention score computation.
- Reference: https://github.com/k2-fsa/icefall/pull/458
- Args:
- query: Transformed query tensor. (B, H, T_1, d_k)
- key: Transformed key tensor. (B, H, T_2, d_k)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- left_context: Number of frames in left context.
- Returns:
- : Attention score. (B, H, T_1, T_2)
- """
- pos_enc = self.linear_pos(pos_enc)
-
- matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
- matrix_bd = self.rel_shift(
- pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
- left_context=left_context,
- )
-
- return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
- def compute_attention_score(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- pos_enc: torch.Tensor,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Attention score computation.
- Args:
- query: Transformed query tensor. (B, H, T_1, d_k)
- key: Transformed key tensor. (B, H, T_2, d_k)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- left_context: Number of frames in left context.
- Returns:
- : Attention score. (B, H, T_1, T_2)
- """
- p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
- query = query.transpose(1, 2)
- q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
- q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
- matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
- matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
- matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
- return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
- def forward_qkv(
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Transform query, key and value.
- Args:
- query: Query tensor. (B, T_1, size)
- key: Key tensor. (B, T_2, size)
- v: Value tensor. (B, T_2, size)
- Returns:
- q: Transformed query tensor. (B, H, T_1, d_k)
- k: Transformed key tensor. (B, H, T_2, d_k)
- v: Transformed value tensor. (B, H, T_2, d_k)
- """
- n_batch = query.size(0)
-
- q = (
- self.linear_q(query)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
- k = (
- self.linear_k(key)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
- v = (
- self.linear_v(value)
- .view(n_batch, -1, self.num_heads, self.d_k)
- .transpose(1, 2)
- )
-
- return q, k, v
-
- def forward_attention(
- self,
- value: torch.Tensor,
- scores: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Compute attention context vector.
- Args:
- value: Transformed value. (B, H, T_2, d_k)
- scores: Attention score. (B, H, T_1, T_2)
- mask: Source mask. (B, T_2)
- chunk_mask: Chunk mask. (T_1, T_1)
- Returns:
- attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
- """
- batch_size = scores.size(0)
- mask = mask.unsqueeze(1).unsqueeze(2)
- if chunk_mask is not None:
- mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
- scores = scores.masked_fill(mask, float("-inf"))
- self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
- attn_output = self.dropout(self.attn)
- attn_output = torch.matmul(attn_output, value)
-
- attn_output = self.linear_out(
- attn_output.transpose(1, 2)
- .contiguous()
- .view(batch_size, -1, self.num_heads * self.d_k)
- )
-
- return attn_output
-
- def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- pos_enc: torch.Tensor,
- mask: torch.Tensor,
- chunk_mask: Optional[torch.Tensor] = None,
- left_context: int = 0,
- ) -> torch.Tensor:
- """Compute scaled dot product attention with rel. positional encoding.
- Args:
- query: Query tensor. (B, T_1, size)
- key: Key tensor. (B, T_2, size)
- value: Value tensor. (B, T_2, size)
- pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
- mask: Source mask. (B, T_2)
- chunk_mask: Chunk mask. (T_1, T_1)
- left_context: Number of frames in left context.
- Returns:
- : Output tensor. (B, T_1, H * d_k)
- """
- q, k, v = self.forward_qkv(query, key, value)
- scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
- return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
-class CosineDistanceAttention(nn.Module):
- """ Compute Cosine Distance between spk decoder output and speaker profile
- Args:
- profile_path: speaker profile file path (.npy file)
- """
- def __init__(self):
- super().__init__()
- self.softmax = nn.Softmax(dim=-1)
- def forward(self, spk_decoder_out, profile, profile_lens=None):
- """
- Args:
- spk_decoder_out(torch.Tensor):(B, L, D)
- spk_profiles(torch.Tensor):(B, N, D)
- """
- x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
- if profile_lens is not None:
-
- mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
- min_value = float(
- numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
- )
- weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
- weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
- else:
- x = x[:, -1:, :, :]
- weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
- weights = self.softmax(weights_not_softmax) # (B, 1, N)
- spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
- return spk_embedding, weights
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
new file mode 100644
index 0000000..3e8d224
--- /dev/null
+++ b/funasr/models/transformer/decoder.py
@@ -0,0 +1,647 @@
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Decoder definition."""
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
+from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.lightconv import LightweightConvolution
+from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
+from funasr.models.transformer.utils.mask import subsequent_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayer(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+        """Construct a DecoderLayer object."""
+ super(DecoderLayer, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ if self.concat_after:
+ tgt_concat = torch.cat(
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+ )
+ x = residual + self.concat_linear1(tgt_concat)
+ else:
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
+
+
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
+    """Base class of Transformer decoder module.
+
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the number of units of position-wise feed forward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before: whether to use layer_norm before the first block
+ concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied.
+ i.e. x -> x + att(x)
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ ):
+ super().__init__()
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+        # Must be set by the subclass
+ self.decoders = None
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ # tgt_mask: (B, 1, L)
+ tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
+
+ memory = hs_pad
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+ memory.device
+ )
+ # Padding for Longformer
+ if memory_mask.shape[-1] != memory.shape[1]:
+ padlen = memory.shape[1] - memory_mask.shape[-1]
+ memory_mask = torch.nn.functional.pad(
+ memory_mask, (0, padlen), "constant", False
+ )
+
+ x = self.embed(tgt)
+ x, tgt_mask, memory, memory_mask = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch < 1.2
+                      dtype=torch.bool in PyTorch >= 1.2
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+            `y.shape` is (batch, maxlen_out, token)
+ """
+ x = self.embed(tgt)
+ if cache is None:
+ cache = [None] * len(self.decoders)
+ new_cache = []
+ for c, decoder in zip(cache, self.decoders):
+ x, tgt_mask, memory, memory_mask = decoder(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(x)
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+ return y, new_cache
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+ )
+ return logp.squeeze(0), state
+
+ def batch_score(
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[Any]]:
+ """Score new token batch.
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, List[Any]]: Tuple of
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
+ and next state list for ys.
+
+ """
+ # merge states
+ n_batch = len(ys)
+ n_layers = len(self.decoders)
+ if states[0] is None:
+ batch_state = None
+ else:
+ # transpose state of [batch, layer] into [layer, batch]
+ batch_state = [
+ torch.stack([states[b][i] for b in range(n_batch)])
+ for i in range(n_layers)
+ ]
+
+ # batch decoding
+ ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
+ logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
+
+ # transpose state of [layer, batch] into [batch, layer]
+ state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+ return logp, state_list
+
+@register_class("decoder_classes", "TransformerDecoder")
+class TransformerDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: bool = False,
+ ):
+ if len(conv_kernel_length) != num_blocks:
+ raise ValueError(
+ "conv_kernel_length must have equal number of values to num_blocks: "
+ f"{len(conv_kernel_length)} != {num_blocks}"
+ )
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ LightweightConvolution(
+ wshare=conv_wshare,
+ n_feat=attention_dim,
+ dropout_rate=self_attention_dropout_rate,
+ kernel_size=conv_kernel_length[lnum],
+ use_kernel_mask=True,
+ use_bias=conv_usebias,
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: bool = False,
+ ):
+ if len(conv_kernel_length) != num_blocks:
+ raise ValueError(
+ "conv_kernel_length must have equal number of values to num_blocks: "
+ f"{len(conv_kernel_length)} != {num_blocks}"
+ )
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ LightweightConvolution2D(
+ wshare=conv_wshare,
+ n_feat=attention_dim,
+ dropout_rate=self_attention_dropout_rate,
+ kernel_size=conv_kernel_length[lnum],
+ use_kernel_mask=True,
+ use_bias=conv_usebias,
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: bool = False,
+ ):
+ if len(conv_kernel_length) != num_blocks:
+ raise ValueError(
+ "conv_kernel_length must have equal number of values to num_blocks: "
+ f"{len(conv_kernel_length)} != {num_blocks}"
+ )
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+ attention_dim = encoder_output_size
+
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ DynamicConvolution(
+ wshare=conv_wshare,
+ n_feat=attention_dim,
+ dropout_rate=self_attention_dropout_rate,
+ kernel_size=conv_kernel_length[lnum],
+ use_kernel_mask=True,
+ use_bias=conv_usebias,
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: bool = False,
+ ):
+ if len(conv_kernel_length) != num_blocks:
+ raise ValueError(
+ "conv_kernel_length must have equal number of values to num_blocks: "
+ f"{len(conv_kernel_length)} != {num_blocks}"
+ )
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+ attention_dim = encoder_output_size
+
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ DynamicConvolution2D(
+ wshare=conv_wshare,
+ n_feat=attention_dim,
+ dropout_rate=self_attention_dropout_rate,
+ kernel_size=conv_kernel_length[lnum],
+ use_kernel_mask=True,
+ use_bias=conv_usebias,
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
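A minimal usage sketch for the new TransformerDecoder added above (a sketch only: the vocabulary size, batch shapes, and dummy tensors are illustrative and assume the package is installed as laid out in this patch):

    import torch
    from funasr.models.transformer.decoder import TransformerDecoder

    # hypothetical sizes for illustration
    decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
    hs_pad = torch.randn(2, 50, 256)             # encoder memory (B, T_in, D)
    hlens = torch.tensor([50, 40])               # valid encoder output lengths
    ys_in_pad = torch.randint(0, 100, (2, 12))   # target token ids (sos already prepended)
    ys_in_lens = torch.tensor([12, 9])
    logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
    # logits: (2, 12, 100), per-token scores before softmax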
diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py
new file mode 100644
index 0000000..a3d5249
--- /dev/null
+++ b/funasr/models/transformer/encoder.py
@@ -0,0 +1,332 @@
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Transformer encoder definition."""
+
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import torch
+from torch import nn
+import logging
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.ctc.ctc import CTC
+
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+
+from funasr.utils.register import register_class
+
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+        stochastic_depth_rate (float): Probability to skip this layer.
+ During training, the layer may skip residual computation and return input
+ as-is with given probability.
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+
+ def forward(self, x, mask, cache=None):
+ """Compute encoded features.
+
+ Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x_q, x, x, mask)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, mask
+
+@register_class("encoder_classes", "TransformerEncoder")
+class TransformerEncoder(nn.Module):
+ """Transformer encoder module.
+
+ Args:
+ input_size: input dim
+ output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the number of units of position-wise feed forward
+        num_blocks: the number of encoder blocks
+ dropout_rate: dropout rate
+ attention_dropout_rate: dropout rate in attention
+ positional_dropout_rate: dropout rate after adding positional encoding
+ input_layer: input layer type
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before: whether to use layer_norm before the first block
+ concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied.
+ i.e. x -> x + att(x)
+        positionwise_layer_type: linear or conv1d
+ positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
+ padding_idx: padding_idx for input_layer=embed
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EncoderLayer(
+ output_size,
+ MultiHeadedAttention(
+ attention_heads, output_size, attention_dropout_rate
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
+ ):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ xs_pad, masks = self.encoders(xs_pad, masks)
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ xs_pad, masks = encoder_layer(xs_pad, masks)
+
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
+
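An analogous sanity check for the new TransformerEncoder (again a sketch; the 80-dim fbank input and frame counts are arbitrary, and the default conv2d front-end subsamples time by roughly 4x):

    import torch
    from funasr.models.transformer.encoder import TransformerEncoder

    encoder = TransformerEncoder(input_size=80, output_size=256, num_blocks=2)
    xs_pad = torch.randn(2, 200, 80)    # padded fbank features (B, T, D)
    ilens = torch.tensor([200, 160])    # true lengths before padding
    enc_out, enc_lens, _ = encoder(xs_pad, ilens)
    # enc_out: (2, ~49, 256) after conv2d subsampling; enc_lens holds the subsampled lengths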
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index 107d62f..e4eae10 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -1,56 +1,23 @@
import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
import torch
import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.register import register_class, registry_tables
-
+@register_class("model_classes", "Transformer")
class Transformer(nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
@@ -93,19 +60,19 @@
super().__init__()
if frontend is not None:
- frontend_class = frontend_choices.get_class(frontend)
+ frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
frontend = frontend_class(**frontend_conf)
if specaug is not None:
- specaug_class = specaug_choices.get_class(specaug)
+ specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
specaug = specaug_class(**specaug_conf)
if normalize is not None:
- normalize_class = normalize_choices.get_class(normalize)
+ normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
normalize = normalize_class(**normalize_conf)
- encoder_class = encoder_choices.get_class(encoder)
+ encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()
if decoder is not None:
- decoder_class = decoder_choices.get_class(decoder)
+ decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -239,7 +206,7 @@
) * loss_ctc + self.interctc_weight * loss_interctc
# decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+ loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
@@ -428,7 +395,7 @@
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
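The hunks above replace the old *_choices factories with the central registry: classes self-register via the register_class decorator and are resolved from registry_tables by name. A sketch of the lookup side, assuming get_class accepts the lower-cased names the constructor passes in and that the module defining the class has already been imported so the decorator has run:

    from funasr.utils.register import registry_tables

    # mirrors Transformer.__init__: resolve the class by name, then build it from its conf
    encoder_class = registry_tables.encoder_classes.get_class("transformerencoder")
    encoder = encoder_class(input_size=80, output_size=256)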
diff --git a/funasr/models/transformer/positionwise_feed_forward.py b/funasr/models/transformer/positionwise_feed_forward.py
index ffa0f4e..7ca55cb 100644
--- a/funasr/models/transformer/positionwise_feed_forward.py
+++ b/funasr/models/transformer/positionwise_feed_forward.py
@@ -34,25 +34,3 @@
return self.w_2(self.dropout(self.activation(self.w_1(x))))
-class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
- """Positionwise feed forward layer.
-
- Args:
- idim (int): Input dimenstion.
- hidden_units (int): The number of hidden units.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
- """Construct an PositionwiseFeedForward object."""
- super(PositionwiseFeedForwardDecoderSANM, self).__init__()
- self.w_1 = torch.nn.Linear(idim, hidden_units)
- self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
- self.dropout = torch.nn.Dropout(dropout_rate)
- self.activation = activation
- self.norm = LayerNorm(hidden_units)
-
- def forward(self, x):
- """Forward function."""
- return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py
index 0beb083..ce151a0 100644
--- a/funasr/models/transformer/utils/nets_utils.py
+++ b/funasr/models/transformer/utils/nets_utils.py
@@ -342,27 +342,6 @@
return ret
-def th_accuracy(pad_outputs, pad_targets, ignore_label):
- """Calculate accuracy.
-
- Args:
- pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
- pad_targets (LongTensor): Target label tensors (B, Lmax).
- ignore_label (int): Ignore label id.
-
- Returns:
- float: Accuracy value (0.0 - 1.0).
-
- """
- pad_pred = pad_outputs.view(
- pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
- ).argmax(2)
- mask = pad_targets != ignore_label
- numerator = torch.sum(
- pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
- )
- denominator = torch.sum(mask)
- return float(numerator) / float(denominator)
def to_torch_tensor(x):
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/e2e_uni_asr.py
index 0fb4039..46c5832 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/e2e_uni_asr.py
@@ -10,15 +10,15 @@
import torch
from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
@@ -26,7 +26,7 @@
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.scama.chunk_utilis import sequence_mask
-from funasr.models.predictor.cif import mae_loss
+from funasr.models.paraformer.cif_predictor import mae_loss
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
diff --git a/funasr/models/xvector/e2e_sv.py b/funasr/models/xvector/e2e_sv.py
index 3eac9ef..fce0dfe 100644
--- a/funasr/models/xvector/e2e_sv.py
+++ b/funasr/models/xvector/e2e_sv.py
@@ -20,13 +20,13 @@
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index b4dfe5d..177f89e 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -2,7 +2,7 @@
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
-optim_choices = dict(
+optim_classes = dict(
adam=torch.optim.Adam,
fairseq_adam=FairseqAdam,
adamw=torch.optim.AdamW,
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index 7bb8118..2ee3a9e 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -7,7 +7,7 @@
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
-scheduler_choices = dict(
+scheduler_classes = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
steplr=torch.optim.lr_scheduler.StepLR,
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index bbfd173..349ebc0 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -7,6 +7,7 @@
from typing import Iterable
from typing import List
from typing import Union
+import json
import numpy as np
@@ -27,7 +28,7 @@
):
if token_list is not None:
- if isinstance(token_list, (Path, str)):
+            if isinstance(token_list, (Path, str)) and str(token_list).endswith(".txt"):
token_list = Path(token_list)
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
@@ -36,7 +37,14 @@
for idx, line in enumerate(f):
line = line.rstrip()
self.token_list.append(line)
-
+            elif isinstance(token_list, (Path, str)) and str(token_list).endswith(".json"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+                with open(token_list, "r", encoding="utf-8") as f:
+                    self.token_list = json.load(f)
+
else:
self.token_list: List[str] = list(token_list)
self.token_list_repr = ""
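With the .json branch above, a JSON token list is simply a JSON array of token strings that is loaded verbatim into self.token_list. A sketch of producing such a file (file name and tokens are hypothetical):

    import json

    tokens = ["<blank>", "<s>", "</s>", "<unk>", "a", "b", "c"]
    with open("tokens.json", "w", encoding="utf-8") as f:
        json.dump(tokens, f, ensure_ascii=False)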
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 80528a2..23ff743 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -4,10 +4,10 @@
from typing import Union
import warnings
-
-from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import BaseTokenizer
+from funasr.utils.register import register_class
+@register_class("tokenizer_classes", "CharTokenizer")
class CharTokenizer(BaseTokenizer):
def __init__(
self,
diff --git a/funasr/tokenizer/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
index 04b423b..f1f7168 100644
--- a/funasr/tokenizer/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -13,7 +13,7 @@
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
-g2p_choices = [
+g2p_classes = [
None,
"g2p_en",
"g2p_en_no_space",
diff --git a/funasr/train_utils/collect_stats.py b/funasr/train_utils/collect_stats.py
deleted file mode 100644
index 7454ccf..0000000
--- a/funasr/train_utils/collect_stats.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from collections import defaultdict
-import logging
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from torch.utils.data import DataLoader
-
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.fileio.npy_scp import NpyScpWriter
-from funasr.train_utils.device_funcs import to_device
-from funasr.train_utils.forward_adaptor import ForwardAdaptor
-from funasr.models.base_model import FunASRModel
-
-
-@torch.no_grad()
-def collect_stats(
- model: FunASRModel,
- train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
- valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
- output_dir: Path,
- ngpu: Optional[int],
- log_interval: Optional[int],
- write_collected_feats: bool,
-) -> None:
- """Perform on collect_stats mode.
-
- Running for deriving the shape information from data
- and gathering statistics.
- This method is used before executing train().
-
- """
-
- npy_scp_writers = {}
- for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
- if log_interval is None:
- try:
- log_interval = max(len(itr) // 20, 10)
- except TypeError:
- log_interval = 100
-
- sum_dict = defaultdict(lambda: 0)
- sq_dict = defaultdict(lambda: 0)
- count_dict = defaultdict(lambda: 0)
-
- with DatadirWriter(output_dir / mode) as datadir_writer:
- for iiter, (keys, batch) in enumerate(itr, 1):
- batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-
- # 1. Write shape file
- for name in batch:
- if name.endswith("_lengths"):
- continue
- for i, (key, data) in enumerate(zip(keys, batch[name])):
- if f"{name}_lengths" in batch:
- lg = int(batch[f"{name}_lengths"][i])
- data = data[:lg]
- datadir_writer[f"{name}_shape"][key] = ",".join(
- map(str, data.shape)
- )
-
- # 2. Extract feats
- if ngpu <= 1:
- data = model.collect_feats(**batch)
- else:
- # Note that data_parallel can parallelize only "forward()"
- data = data_parallel(
- ForwardAdaptor(model, "collect_feats"),
- (),
- range(ngpu),
- module_kwargs=batch,
- )
-
- # 3. Calculate sum and square sum
- for key, v in data.items():
- for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
- # Truncate zero-padding region
- if f"{key}_lengths" in data:
- length = data[f"{key}_lengths"][i]
- # seq: (Length, Dim, ...)
- seq = seq[:length]
- else:
- # seq: (Dim, ...) -> (1, Dim, ...)
- seq = seq[None]
- # Accumulate value, its square, and count
- sum_dict[key] += seq.sum(0)
- sq_dict[key] += (seq**2).sum(0)
- count_dict[key] += len(seq)
-
- # 4. [Option] Write derived features as npy format file.
- if write_collected_feats:
- # Instantiate NpyScpWriter for the first iteration
- if (key, mode) not in npy_scp_writers:
- p = output_dir / mode / "collect_feats"
- npy_scp_writers[(key, mode)] = NpyScpWriter(
- p / f"data_{key}", p / f"{key}.scp"
- )
- # Save array as npy file
- npy_scp_writers[(key, mode)][uttid] = seq
-
- if iiter % log_interval == 0:
- logging.info(f"Niter: {iiter}")
-
- for key in sum_dict:
- np.savez(
- output_dir / mode / f"{key}_stats.npz",
- count=count_dict[key],
- sum=sum_dict[key],
- sum_square=sq_dict[key],
- )
-
- # batch_keys and stats_keys are used by aggregate_stats_dirs.py
- with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
- f.write(
- "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
- )
- with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
- f.write("\n".join(sum_dict) + "\n")
diff --git a/funasr/train_utils/model_summary.py b/funasr/train_utils/model_summary.py
index 8d7f14f..1001160 100644
--- a/funasr/train_utils/model_summary.py
+++ b/funasr/train_utils/model_summary.py
@@ -1,4 +1,3 @@
-import humanfriendly
import numpy as np
import torch
@@ -59,12 +58,7 @@
message += (
f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
)
- num_bytes = humanfriendly.format_size(
- sum(
- p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
- )
- )
- message += f" Size: {num_bytes}\n"
+
dtype = next(iter(model.parameters())).dtype
message += f" Type: {dtype}"
return message
diff --git a/funasr/train_utils/pack_funcs.py b/funasr/train_utils/pack_funcs.py
deleted file mode 100644
index fe365d8..0000000
--- a/funasr/train_utils/pack_funcs.py
+++ /dev/null
@@ -1,302 +0,0 @@
-from datetime import datetime
-from io import BytesIO
-from io import TextIOWrapper
-import os
-from pathlib import Path
-import sys
-import tarfile
-from typing import Dict
-from typing import Iterable
-from typing import Optional
-from typing import Union
-import zipfile
-
-import yaml
-
-
-class Archiver:
- def __init__(self, file, mode="r"):
- if Path(file).suffix == ".tar":
- self.type = "tar"
- elif Path(file).suffix == ".tgz" or Path(file).suffixes == [".tar", ".gz"]:
- self.type = "tar"
- if mode == "w":
- mode = "w:gz"
- elif Path(file).suffix == ".tbz2" or Path(file).suffixes == [".tar", ".bz2"]:
- self.type = "tar"
- if mode == "w":
- mode = "w:bz2"
- elif Path(file).suffix == ".txz" or Path(file).suffixes == [".tar", ".xz"]:
- self.type = "tar"
- if mode == "w":
- mode = "w:xz"
- elif Path(file).suffix == ".zip":
- self.type = "zip"
- else:
- raise ValueError(f"Cannot detect archive format: type={file}")
-
- if self.type == "tar":
- self.fopen = tarfile.open(file, mode=mode)
- elif self.type == "zip":
-
- self.fopen = zipfile.ZipFile(file, mode=mode)
- else:
- raise ValueError(f"Not supported: type={type}")
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.fopen.close()
-
- def close(self):
- self.fopen.close()
-
- def __iter__(self):
- if self.type == "tar":
- return iter(self.fopen)
- elif self.type == "zip":
- return iter(self.fopen.infolist())
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def add(self, filename, arcname=None, recursive: bool = True):
- if arcname is not None:
- print(f"adding: {arcname}")
- else:
- print(f"adding: {filename}")
-
- if recursive and Path(filename).is_dir():
- for f in Path(filename).glob("**/*"):
- if f.is_dir():
- continue
-
- if arcname is not None:
- _arcname = Path(arcname) / f
- else:
- _arcname = None
-
- self.add(f, _arcname)
- return
-
- if self.type == "tar":
- return self.fopen.add(filename, arcname)
- elif self.type == "zip":
- return self.fopen.write(filename, arcname)
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def addfile(self, info, fileobj):
- print(f"adding: {self.get_name_from_info(info)}")
-
- if self.type == "tar":
- return self.fopen.addfile(info, fileobj)
- elif self.type == "zip":
- return self.fopen.writestr(info, fileobj.read())
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
- """Generate TarInfo using system information"""
- if self.type == "tar":
- tarinfo = tarfile.TarInfo(str(name))
- if os.name == "posix":
- tarinfo.gid = os.getgid()
- tarinfo.uid = os.getuid()
- tarinfo.mtime = datetime.now().timestamp()
- tarinfo.size = size
- # Keep mode as default
- return tarinfo
- elif self.type == "zip":
- zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
- zipinfo.file_size = size
- return zipinfo
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def get_name_from_info(self, info):
- if self.type == "tar":
- assert isinstance(info, tarfile.TarInfo), type(info)
- return info.name
- elif self.type == "zip":
- assert isinstance(info, zipfile.ZipInfo), type(info)
- return info.filename
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def extract(self, info, path=None):
- if self.type == "tar":
- return self.fopen.extract(info, path)
- elif self.type == "zip":
- return self.fopen.extract(info, path)
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
- def extractfile(self, info, mode="r"):
- if self.type == "tar":
- f = self.fopen.extractfile(info)
- if mode == "r":
- return TextIOWrapper(f)
- else:
- return f
- elif self.type == "zip":
- if mode == "rb":
- mode = "r"
- return self.fopen.open(info, mode)
- else:
- raise ValueError(f"Not supported: type={self.type}")
-
-
-def find_path_and_change_it_recursive(value, src: str, tgt: str):
- if isinstance(value, dict):
- return {
- k: find_path_and_change_it_recursive(v, src, tgt) for k, v in value.items()
- }
- elif isinstance(value, (list, tuple)):
- return [find_path_and_change_it_recursive(v, src, tgt) for v in value]
- elif isinstance(value, str) and Path(value) == Path(src):
- return tgt
- else:
- return value
-
-
-def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]:
- meta = Path(meta)
- outpath = meta.parent.parent
- if not meta.exists():
- return None
-
- with meta.open("r", encoding="utf-8") as f:
- d = yaml.safe_load(f)
- assert isinstance(d, dict), type(d)
- yaml_files = d["yaml_files"]
- files = d["files"]
- assert isinstance(yaml_files, dict), type(yaml_files)
- assert isinstance(files, dict), type(files)
-
- retval = {}
- for key, value in list(yaml_files.items()) + list(files.items()):
- if not (outpath / value).exists():
- return None
- retval[key] = str(outpath / value)
- return retval
-
-
-def unpack(
- input_archive: Union[Path, str],
- outpath: Union[Path, str],
- use_cache: bool = True,
-) -> Dict[str, str]:
- """Scan all files in the archive file and return as a dict of files.
-
- Examples:
- tarfile:
- model.pb
- some1.file
- some2.file
-
- >>> unpack("tarfile", "out")
- {'asr_model_file': 'out/model.pb'}
- """
- input_archive = Path(input_archive)
- outpath = Path(outpath)
-
- with Archiver(input_archive) as archive:
- for info in archive:
- if Path(archive.get_name_from_info(info)).name == "meta.yaml":
- if (
- use_cache
- and (outpath / Path(archive.get_name_from_info(info))).exists()
- ):
- retval = get_dict_from_cache(
- outpath / Path(archive.get_name_from_info(info))
- )
- if retval is not None:
- return retval
- d = yaml.safe_load(archive.extractfile(info))
- assert isinstance(d, dict), type(d)
- yaml_files = d["yaml_files"]
- files = d["files"]
- assert isinstance(yaml_files, dict), type(yaml_files)
- assert isinstance(files, dict), type(files)
- break
- else:
- raise RuntimeError("Format error: not found meta.yaml")
-
- for info in archive:
- fname = archive.get_name_from_info(info)
- outname = outpath / fname
- outname.parent.mkdir(parents=True, exist_ok=True)
- if fname in set(yaml_files.values()):
- d = yaml.safe_load(archive.extractfile(info))
- # Rewrite yaml
- for info2 in archive:
- name = archive.get_name_from_info(info2)
- d = find_path_and_change_it_recursive(d, name, str(outpath / name))
- with outname.open("w", encoding="utf-8") as f:
- yaml.safe_dump(d, f)
- else:
- archive.extract(info, path=outpath)
-
- retval = {}
- for key, value in list(yaml_files.items()) + list(files.items()):
- retval[key] = str(outpath / value)
- return retval
-
-
-def _to_relative_or_resolve(f):
- # Resolve to avoid symbolic link
- p = Path(f).resolve()
- try:
- # Change to relative if it can
- p = p.relative_to(Path(".").resolve())
- except ValueError:
- pass
- return str(p)
-
-
-def pack(
- files: Dict[str, Union[str, Path]],
- yaml_files: Dict[str, Union[str, Path]],
- outpath: Union[str, Path],
- option: Iterable[Union[str, Path]] = (),
-):
- for v in list(files.values()) + list(yaml_files.values()) + list(option):
- if not Path(v).exists():
- raise FileNotFoundError(f"No such file or directory: {v}")
-
- files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
- yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
- option = [_to_relative_or_resolve(v) for v in option]
-
- meta_objs = dict(
- files=files,
- yaml_files=yaml_files,
- timestamp=datetime.now().timestamp(),
- python=sys.version,
- )
-
- try:
- import torch
-
- meta_objs.update(torch=str(torch.__version__))
- except ImportError:
- pass
- try:
- import espnet
-
- meta_objs.update(espnet=espnet.__version__)
- except ImportError:
- pass
-
- Path(outpath).parent.mkdir(parents=True, exist_ok=True)
- with Archiver(outpath, mode="w") as archive:
- # Write packed/meta.yaml
- fileobj = BytesIO(yaml.safe_dump(meta_objs).encode())
- info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes)
- archive.addfile(info, fileobj=fileobj)
-
- for f in list(yaml_files.values()) + list(files.values()) + list(option):
- archive.add(f)
-
- print(f"Generate: {outpath}")
diff --git a/funasr/train_utils/pytorch_version.py b/funasr/train_utils/pytorch_version.py
deleted file mode 100644
index 01f17cc..0000000
--- a/funasr/train_utils/pytorch_version.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-
-
-def pytorch_cudnn_version() -> str:
- message = (
- f"pytorch.version={torch.__version__}, "
- f"cuda.available={torch.cuda.is_available()}, "
- )
-
- if torch.backends.cudnn.enabled:
- message += (
- f"cudnn.version={torch.backends.cudnn.version()}, "
- f"cudnn.benchmark={torch.backends.cudnn.benchmark}, "
- f"cudnn.deterministic={torch.backends.cudnn.deterministic}"
- )
- return message
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 425b79f..ea502f7 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -55,7 +55,7 @@
self.dataloader_val = dataloader_val
self.output_dir = kwargs.get('output_dir', './')
self.resume = kwargs.get('resume', True)
- self.start_epoch = 1
+ self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
self.local_rank = local_rank
self.use_ddp = use_ddp
@@ -123,7 +123,7 @@
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
# self._validate_epoch(epoch)
- if dist.get_rank() == 0:
+ if self.rank == 0:
self._save_checkpoint(epoch)
self.scheduler.step()
break
@@ -201,21 +201,22 @@
speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
speed_stats["total_time"] = total_time
-
+ # import pdb;
+ # pdb.set_trace()
pbar.update(1)
if self.local_rank == 0:
description = (
f"Epoch: {epoch + 1}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
- f"(loss: {loss.detach().float():.3f}), "
+ f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
)
pbar.set_description(description)
- if batch_idx == 2:
- break
+ # if batch_idx == 2:
+ # break
pbar.close()
def _validate_epoch(self, epoch):
diff --git a/funasr/utils/class_choices.py b/funasr/utils/class_choices.py
deleted file mode 100644
index 1ffb97a..0000000
--- a/funasr/utils/class_choices.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import Mapping
-from typing import Optional
-from typing import Tuple
-
-
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import str_or_none
-
-
-class ClassChoices:
- """Helper class to manage the options for variable objects and its configuration.
-
- Example:
-
- >>> class A:
- ... def __init__(self, foo=3): pass
- >>> class B:
- ... def __init__(self, bar="aaaa"): pass
- >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
- >>> import argparse
- >>> parser = argparse.ArgumentParser()
- >>> choices.add_arguments(parser)
- >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4")
- >>> args.var
- a
- >>> args.var_conf
- {"foo": 4}
- >>> class_obj = choices.get_class(args.var)
- >>> a_object = class_obj(**args.var_conf)
-
- """
-
- def __init__(
- self,
- name: str,
- classes: Mapping[str, type],
- type_check: type = None,
- default: str = None,
- optional: bool = False,
- ):
- self.name = name
- self.base_type = type_check
- self.classes = {k.lower(): v for k, v in classes.items()}
- if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
- raise ValueError('"none", "nil", and "null" are reserved.')
- if type_check is not None:
- for v in self.classes.values():
- if not issubclass(v, type_check):
- raise ValueError(f"must be {type_check.__name__}, but got {v}")
-
- self.optional = optional
- self.default = default
- if default is None:
- self.optional = True
-
- def choices(self) -> Tuple[Optional[str], ...]:
- retval = tuple(self.classes)
- if self.optional:
- return retval + (None,)
- else:
- return retval
-
- def get_class(self, name: Optional[str]) -> Optional[type]:
- if name is None or (self.optional and name.lower() == ("none", "null", "nil")):
- retval = None
- elif name.lower() in self.classes:
- class_obj = self.classes[name]
- retval = class_obj
- else:
- raise ValueError(
- f"--{self.name} must be one of {self.choices()}: "
- f"--{self.name} {name.lower()}"
- )
-
- return retval
-
- def add_arguments(self, parser):
- parser.add_argument(
- f"--{self.name}",
- type=lambda x: str_or_none(x.lower()),
- default=self.default,
- choices=self.choices(),
- help=f"The {self.name} type",
- )
- parser.add_argument(
- f"--{self.name}_conf",
- action=NestedDictAction,
- default=dict(),
- help=f"The keyword arguments for {self.name}",
- )
diff --git a/funasr/utils/cli_utils.py b/funasr/utils/cli_utils.py
deleted file mode 100644
index c4a4cd1..0000000
--- a/funasr/utils/cli_utils.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from collections.abc import Sequence
-from distutils.util import strtobool as dist_strtobool
-import sys
-
-import numpy
-
-
-def strtobool(x):
- # distutils.util.strtobool returns integer, but it's confusing,
- return bool(dist_strtobool(x))
-
-
-def get_commandline_args():
- extra_chars = [
- " ",
- ";",
- "&",
- "(",
- ")",
- "|",
- "^",
- "<",
- ">",
- "?",
- "*",
- "[",
- "]",
- "$",
- "`",
- '"',
- "\\",
- "!",
- "{",
- "}",
- ]
-
- # Escape the extra characters for shell
- argv = [
- arg.replace("'", "'\\''")
- if all(char not in arg for char in extra_chars)
- else "'" + arg.replace("'", "'\\''") + "'"
- for arg in sys.argv
- ]
-
- return sys.executable + " " + " ".join(argv)
-
-
-def is_scipy_wav_style(value):
- # If Tuple[int, numpy.ndarray] or not
- return (
- isinstance(value, Sequence)
- and len(value) == 2
- and isinstance(value[0], int)
- and isinstance(value[1], numpy.ndarray)
- )
-
-
-def assert_scipy_wav_style(value):
- assert is_scipy_wav_style(
- value
- ), "Must be Tuple[int, numpy.ndarray], but got {}".format(
- type(value)
- if not isinstance(value, Sequence)
- else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
- )
diff --git a/funasr/utils/dynamic_import.py b/funasr/utils/dynamic_import.py
deleted file mode 100644
index 2830cb2..0000000
--- a/funasr/utils/dynamic_import.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import importlib
-
-
-def dynamic_import(import_path):
- """dynamic import module and class
-
- :param str import_path: syntax 'module_name:class_name'
- :return: imported class
- """
-
- module_name, objname = import_path.split(":")
- m = importlib.import_module(module_name)
- return getattr(m, objname)
diff --git a/funasr/utils/load_fr_tf.py b/funasr/utils/load_fr_tf.py
deleted file mode 100644
index 5c8c275..0000000
--- a/funasr/utils/load_fr_tf.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import numpy as np
-np.set_printoptions(threshold=np.inf)
-import logging
-
-def load_ckpt(checkpoint_path):
- import tensorflow as tf
- if tf.__version__.startswith('2'):
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
- reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
- else:
- from tensorflow.python import pywrap_tensorflow
- reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
- var_to_shape_map = reader.get_variable_to_shape_map()
-
- var_dict = dict()
- for var_name in sorted(var_to_shape_map):
- if "Adam" in var_name:
- continue
- tensor = reader.get_tensor(var_name)
- # print("in ckpt: {}, {}".format(var_name, tensor.shape))
- # print(tensor)
- var_dict[var_name] = tensor
-
- return var_dict
-
-
-
-def load_tf_pb_dict(pb_model):
- import tensorflow as tf
- if tf.__version__.startswith('2'):
- import tensorflow.compat.v1 as tf
- tf.disable_v2_behavior()
- # import tensorflow_addons as tfa
- # from tensorflow_addons.seq2seq.python.ops import beam_search_ops
- else:
- from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
- from tensorflow.python.ops import lookup_ops as lookup
- from tensorflow.python.framework import tensor_util
- from tensorflow.python.platform import gfile
-
- sess = tf.Session()
- with gfile.FastGFile(pb_model, 'rb') as f:
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(f.read())
- sess.graph.as_default()
- tf.import_graph_def(graph_def, name='')
-
- var_dict = dict()
- for node in sess.graph_def.node:
- if node.op == 'Const':
- value = tensor_util.MakeNdarray(node.attr['value'].tensor)
- if len(value.shape) >= 1:
- var_dict[node.name] = value
- return var_dict
-
-def load_tf_dict(pb_model):
- if "model.ckpt-" in pb_model:
- var_dict = load_ckpt(pb_model)
- else:
- var_dict = load_tf_pb_dict(pb_model)
- return var_dict
diff --git a/funasr/utils/register.py b/funasr/utils/register.py
new file mode 100644
index 0000000..0dfcdab
--- /dev/null
+++ b/funasr/utils/register.py
@@ -0,0 +1,72 @@
+import logging
+import inspect
+from dataclasses import dataclass, fields
+
+
+@dataclass
+class ClassRegistryTables:
+ model_classes = {}
+ frontend_classes = {}
+ specaug_classes = {}
+ normalize_classes = {}
+ encoder_classes = {}
+ decoder_classes = {}
+ joint_network_classes = {}
+ predictor_classes = {}
+ stride_conv_classes = {}
+ tokenizer_classes = {}
+ batch_sampler_classes = {}
+ dataset_classes = {}
+ index_ds_classes = {}
+
+ def print_register_tables(self,):
+ print("\nregister_tables: \n")
+ fields = vars(self)
+ for classes_key, classes_dict in fields.items():
+ print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
+
+ if classes_key.endswith("_meta"):
+ headers = ["class name", "register name", "class location"]
+ metas = []
+ for register_key, meta in classes_dict.items():
+ metas.append(meta)
+ metas.sort(key=lambda x: x[0])
+ data = [headers] + metas
+ col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
+
+ for row in data:
+ print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
+ print("\n")
+
+registry_tables = ClassRegistryTables()
+
+def register_class(registry_tables_key:str, key=None):
+ def decorator(target_class):
+
+ if not hasattr(registry_tables, registry_tables_key):
+ setattr(registry_tables, registry_tables_key, {})
+ logging.info("new registry table has been added: {}".format(registry_tables_key))
+
+ registry = getattr(registry_tables, registry_tables_key)
+ registry_key = key if key is not None else target_class.__name__
+ registry_key = registry_key.lower()
+ # import pdb; pdb.set_trace()
+ assert registry_key not in registry, "(key: {} / class: {}) has been registered already in {}".format(registry_key, target_class, registry_tables_key)
+
+ registry[registry_key] = target_class
+
+ # meta: headers = ["class name", "register name", "class location"]
+ registry_tables_key_meta = registry_tables_key + "_meta"
+ if not hasattr(registry_tables, registry_tables_key_meta):
+ setattr(registry_tables, registry_tables_key_meta, {})
+ registry_meta = getattr(registry_tables, registry_tables_key_meta)
+ class_file = inspect.getfile(target_class)
+ class_line = inspect.getsourcelines(target_class)[1]
+ meta_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+ registry_meta[registry_key] = meta_data
+ # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
+ return target_class
+ return decorator
+
+import funasr
+
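A minimal usage sketch of the new register_class decorator (the MyTokenizer class and its key below are hypothetical, chosen only for illustration; the lowercase lookup follows the registry_key.lower() call in register.py, as with the CharTokenizer registration earlier in this patch):

    from funasr.utils.register import register_class, registry_tables

    # Register a hypothetical class into the "tokenizer_classes" table under the key "MyTokenizer".
    @register_class("tokenizer_classes", "MyTokenizer")
    class MyTokenizer:
        def __init__(self, vocab=None):
            self.vocab = vocab

    # register_class lowercases the key, so lookups use "mytokenizer".
    tokenizer_cls = registry_tables.tokenizer_classes["mytokenizer"]
    tokenizer = tokenizer_cls(vocab=["a", "b", "c"])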
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index c463f0c..8186dff 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -269,7 +269,7 @@
def convert_external_alphas(alphas_file, text_file, output_file):
- from funasr.models.predictor.cif import cif_wo_hidden
+ from funasr.models.paraformer.cif_predictor import cif_wo_hidden
with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
for line1, line2 in zip(f1.readlines(), f2.readlines()):
line1 = line1.rstrip()
diff --git a/funasr/utils/yaml_no_alias_safe_dump.py b/funasr/utils/yaml_no_alias_safe_dump.py
deleted file mode 100644
index 70a7b0e..0000000
--- a/funasr/utils/yaml_no_alias_safe_dump.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import yaml
-
-
-class NoAliasSafeDumper(yaml.SafeDumper):
- # Disable anchor/alias in yaml because looks ugly
- def ignore_aliases(self, data):
- return True
-
-
-def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
- """Safe-dump in yaml with no anchor/alias"""
- return yaml.dump(
- data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
- )
--
Gitblit v1.9.1