From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Dec 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer
---
funasr/models/paraformer/model.py | 1760 +++++++++++++++++++++++++++++++++++++++
funasr/models/paraformer/search.py | 453 ++++++++++
.gitignore | 1
examples/industrial_data_pretraining/paraformer-large/infer.sh | 15
funasr/models/frontend/wav_frontend.py | 2
examples/industrial_data_pretraining/paraformer-large/run.sh | 2
funasr/bin/inference.py | 170 +++
funasr/datasets/fun_datasets/load_audio_extract_fbank.py | 75 +
/dev/null | 81 -
funasr/tokenizer/abs_tokenizer.py | 4
funasr/datasets/fun_datasets/__init__.py | 0
funasr/utils/download_from_hub.py | 24
funasr/cli/train_cli.py | 8
funasr/datasets/dataset_jsonl.py | 33
funasr/models/paraformer/__init__.py | 0
funasr/__init__.py | 2
funasr/bin/asr_inference_launch.py | 31
17 files changed, 2,501 insertions(+), 160 deletions(-)
diff --git a/.gitignore b/.gitignore
index 37f39fe..dea4634 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,4 @@
modelscope
samples
.ipynb_checkpoints
+outputs*
diff --git a/examples/industrial_data_pretraining/paraformer-large/infer.sh b/examples/industrial_data_pretraining/paraformer-large/infer.sh
new file mode 100644
index 0000000..b7fbe75
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer-large/infer.sh
@@ -0,0 +1,15 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
++device="cpu" \
++"hotword='达摩院 魔搭'"
+
+#+input="/Users/zhifu/funasr_github/test_local/asr_example.wav" \
+#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/run.sh
index 9b40b81..8571974 100644
--- a/examples/industrial_data_pretraining/paraformer-large/run.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/run.sh
@@ -2,7 +2,7 @@
cmd="funasr/cli/train_cli.py"
python $cmd \
-+model_pretrain="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+token_list="/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
diff --git a/funasr/__init__.py b/funasr/__init__.py
index d0b7aa5..ac1591d 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -7,4 +7,4 @@
with open(version_file, "r") as f:
__version__ = f.read().strip()
-from funasr.bin.inference_cli import infer
\ No newline at end of file
+from funasr.bin.inference import infer
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f34bfb2..6151d28 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1254,37 +1254,6 @@
return cache
- #def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
- # if len(cache) > 0:
- # return cache
- # config = _read_yaml(asr_train_config)
- # enc_output_size = config["encoder_conf"]["output_size"]
- # feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- # cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- # "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- # "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
- # cache["encoder"] = cache_en
-
- # cache_de = {"decode_fsmn": None}
- # cache["decoder"] = cache_de
-
- # return cache
-
- #def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
- # if len(cache) > 0:
- # config = _read_yaml(asr_train_config)
- # enc_output_size = config["encoder_conf"]["output_size"]
- # feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- # cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- # "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- # "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
- # "tail_chunk": False}
- # cache["encoder"] = cache_en
-
- # cache_de = {"decode_fsmn": None}
- # cache["decoder"] = cache_de
-
- # return cache
def _forward(
data_path_and_name_and_type,
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
deleted file mode 100755
index 8161e7b..0000000
--- a/funasr/bin/asr_train.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.asr import ASRTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTask.get_parser()
- parser.add_argument(
- "--mode",
- type=str,
- default="asr",
- help=" ",
- )
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
-
- # for ASR Training
- if args.mode == "asr":
- from funasr.tasks.asr import ASRTask
- if args.mode == "paraformer":
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- if args.mode == "uniasr":
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- if args.mode == "rnnt":
- from funasr.tasks.asr import ASRTransducerTask as ASRTask
-
- ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- if args.ngpu > 0:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None and args.ngpu > 0:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None and args.ngpu > 0:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
-
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
new file mode 100644
index 0000000..4e140e1
--- /dev/null
+++ b/funasr/bin/inference.py
@@ -0,0 +1,170 @@
+import os.path
+
+import torch
+import numpy as np
+import hydra
+import json
+from omegaconf import DictConfig, OmegaConf
+from funasr.utils.dynamic_import import dynamic_import
+import logging
+from funasr.utils.download_from_hub import download_model
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.tokenizer.funtoken import build_tokenizer
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
+from funasr.torch_utils.device_funcs import to_device
+from tqdm import tqdm
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+import time
+import random
+import string
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+ assert "model" in kwargs
+
+ pipeline = infer(**kwargs)
+ res = pipeline(input=kwargs["input"])
+ print(res)
+
+def infer(**kwargs):
+
+ if ":" not in kwargs["model"]:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+
+    device = kwargs.get("device", "cuda")
+    if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+        device = "cpu"
+    batch_size = 1
+    kwargs["device"] = device
+
+ # build_tokenizer
+ tokenizer = build_tokenizer(
+ token_type=kwargs.get("token_type", "char"),
+ bpemodel=kwargs.get("bpemodel", None),
+ delimiter=kwargs.get("delimiter", None),
+ space_symbol=kwargs.get("space_symbol", "<space>"),
+ non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
+ g2p_type=kwargs.get("g2p_type", None),
+ token_list=kwargs.get("token_list", None),
+ unk_symbol=kwargs.get("unk_symbol", "<unk>"),
+ )
+
+    # NOTE(review): removed leftover debugger breakpoint (import pdb; pdb.set_trace())
+    # which would halt every inference run at this point.
+ # build model
+ model_class = dynamic_import(kwargs.get("model"))
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+ model.eval()
+ model.to(device)
+ frontend = model.frontend
+ kwargs["token_list"] = tokenizer.token_list
+
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ init_param=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ )
+
+ def _forward(input, input_len=None, **cfg):
+ cfg = OmegaConf.merge(kwargs, cfg)
+ date_type = cfg.get("date_type", "sound")
+
+ key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ for beg_idx in range(0, num_samples, batch_size):
+
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+
+ time1 = time.perf_counter()
+ results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
+ time2 = time.perf_counter()
+
+ asr_result_list.append(results)
+ pbar.update(1)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ speed_stats["load_data"] = meta_data["load_data"]
+ speed_stats["extract_feat"] = meta_data["extract_feat"]
+ speed_stats["forward"] = f"{time2 - time1:0.3f}"
+ speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
+ description = (
+ f"{speed_stats}, "
+ )
+ pbar.set_description(description)
+
+ torch.cuda.empty_cache()
+ return asr_result_list
+
+ return _forward
+
+
+def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
+ """
+
+ :param input:
+ :param input_len:
+ :param date_type:
+ :param frontend:
+ :return:
+ """
+ data_list = []
+ key_list = []
+ filelist = [".scp", ".txt", ".json", ".jsonl"]
+
+ chars = string.ascii_letters + string.digits
+
+ if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
+ _, file_extension = os.path.splitext(data_in)
+ file_extension = file_extension.lower()
+ if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
+ with open(data_in, encoding='utf-8') as fin:
+ for line in fin:
+ key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+ if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
+ lines = json.loads(line.strip())
+ data = lines["source"]
+ key = data["key"] if "key" in data else key
+ else: # filelist, wav.scp, text.txt: id \t data or data
+ lines = line.strip().split()
+ data = lines[1] if len(lines)>1 else lines[0]
+ key = lines[0] if len(lines)>1 else key
+
+ data_list.append(data)
+ key_list.append(key)
+ else:
+ key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+ elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
+ data_list = data_in
+ key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
+ else: # raw text; audio sample point, fbank
+ if isinstance(data_in, bytes): # audio bytes
+ data_in = load_bytes(data_in)
+ key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+
+ return key_list, data_list
+
+
+if __name__ == '__main__':
+ main_hydra()
\ No newline at end of file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
deleted file mode 100755
index 6aebf8a..0000000
--- a/funasr/bin/train.py
+++ /dev/null
@@ -1,595 +0,0 @@
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from io import BytesIO
-
-import torch
-
-from funasr.build_utils.build_args import build_args
-from funasr.build_utils.build_dataloader import build_dataloader
-from funasr.build_utils.build_distributed import build_distributed
-from funasr.build_utils.build_model import build_model
-from funasr.build_utils.build_optimizer import build_optimizer
-from funasr.build_utils.build_scheduler import build_scheduler
-from funasr.build_utils.build_trainer import build_trainer
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.prepare_data import prepare_data
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
-from funasr.modules.lora.utils import mark_only_lora_as_trainable
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- description="FunASR Common Training Parser",
- )
-
- # common configuration
- parser.add_argument("--output_dir", help="model save path")
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
-
- # ddp related
- parser.add_argument(
- "--dist_backend",
- default="nccl",
- type=str,
- help="distributed backend",
- )
- parser.add_argument(
- "--dist_init_method",
- type=str,
- default="env://",
- help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
- '"WORLD_SIZE", and "RANK" are referred.',
- )
- parser.add_argument(
- "--dist_world_size",
- type=int,
- default=1,
- help="number of nodes for distributed training",
- )
- parser.add_argument(
- "--dist_rank",
- type=int,
- default=None,
- help="node rank for distributed training",
- )
- parser.add_argument(
- "--local_rank",
- type=int,
- default=None,
- help="local rank for distributed training",
- )
- parser.add_argument(
- "--dist_master_addr",
- default=None,
- type=str_or_none,
- help="The master address for distributed training. "
- "This value is used when dist_init_method == 'env://'",
- )
- parser.add_argument(
- "--dist_master_port",
- default=None,
- type=int_or_none,
- help="The master port for distributed training"
- "This value is used when dist_init_method == 'env://'",
- )
- parser.add_argument(
- "--dist_launcher",
- default=None,
- type=str_or_none,
- choices=["slurm", "mpi", None],
- help="The launcher type for distributed training",
- )
- parser.add_argument(
- "--multiprocessing_distributed",
- default=True,
- type=str2bool,
- help="Use multi-processing distributed training to launch "
- "N processes per node, which has N GPUs. This is the "
- "fastest way to use PyTorch for either single node or "
- "multi node data parallel training",
- )
- parser.add_argument(
- "--unused_parameters",
- type=str2bool,
- default=False,
- help="Whether to use the find_unused_parameters in "
- "torch.nn.parallel.DistributedDataParallel ",
- )
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
-
- # cudnn related
- parser.add_argument(
- "--cudnn_enabled",
- type=str2bool,
- default=torch.backends.cudnn.enabled,
- help="Enable CUDNN",
- )
- parser.add_argument(
- "--cudnn_benchmark",
- type=str2bool,
- default=torch.backends.cudnn.benchmark,
- help="Enable cudnn-benchmark mode",
- )
- parser.add_argument(
- "--cudnn_deterministic",
- type=str2bool,
- default=True,
- help="Enable cudnn-deterministic mode",
- )
-
- # trainer related
- parser.add_argument(
- "--max_epoch",
- type=int,
- default=40,
- help="The maximum number epoch to train",
- )
- parser.add_argument(
- "--max_update",
- type=int,
- default=sys.maxsize,
- help="The maximum number update step to train",
- )
- parser.add_argument(
- "--batch_interval",
- type=int,
- default=10000,
- help="The batch interval for saving model.",
- )
- parser.add_argument(
- "--patience",
- type=int_or_none,
- default=None,
- help="Number of epochs to wait without improvement "
- "before stopping the training",
- )
- parser.add_argument(
- "--val_scheduler_criterion",
- type=str,
- nargs=2,
- default=("valid", "loss"),
- help="The criterion used for the value given to the lr scheduler. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'and the criterion name. The mode specifying "min" or "max" can '
- "be changed by --scheduler_conf",
- )
- parser.add_argument(
- "--early_stopping_criterion",
- type=str,
- nargs=3,
- default=("valid", "loss", "min"),
- help="The criterion used for judging of early stopping. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
- )
- parser.add_argument(
- "--best_model_criterion",
- nargs="+",
- default=[
- ("train", "loss", "min"),
- ("valid", "loss", "min"),
- ("train", "acc", "max"),
- ("valid", "acc", "max"),
- ],
- help="The criterion used for judging of the best model. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
- )
- parser.add_argument(
- "--keep_nbest_models",
- type=int,
- nargs="+",
- default=[10],
- help="Remove previous snapshots excluding the n-best scored epochs",
- )
- parser.add_argument(
- "--nbest_averaging_interval",
- type=int,
- default=0,
- help="The epoch interval to apply model averaging and save nbest models",
- )
- parser.add_argument(
- "--grad_clip",
- type=float,
- default=5.0,
- help="Gradient norm threshold to clip",
- )
- parser.add_argument(
- "--grad_clip_type",
- type=float,
- default=2.0,
- help="The type of the used p-norm for gradient clip. Can be inf",
- )
- parser.add_argument(
- "--grad_noise",
- type=str2bool,
- default=False,
- help="The flag to switch to use noise injection to "
- "gradients during training",
- )
- parser.add_argument(
- "--accum_grad",
- type=int,
- default=1,
- help="The number of gradient accumulation",
- )
- parser.add_argument(
- "--resume",
- type=str2bool,
- default=False,
- help="Enable resuming if checkpoint is existing",
- )
- parser.add_argument(
- "--train_dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type for training.",
- )
- parser.add_argument(
- "--use_amp",
- type=str2bool,
- default=False,
- help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
- )
- parser.add_argument(
- "--log_interval",
- default=None,
- help="Show the logs every the number iterations in each epochs at the "
- "training phase. If None is given, it is decided according the number "
- "of training samples automatically .",
- )
- parser.add_argument(
- "--use_tensorboard",
- type=str2bool,
- default=True,
- help="Enable tensorboard logging",
- )
-
- # pretrained model related
- parser.add_argument(
- "--init_param",
- type=str,
- action="append",
- default=[],
- help="Specify the file path used for initialization of parameters. "
- "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
- "where file_path is the model file path, "
- "src_key specifies the key of model states to be used in the model file, "
- "dst_key specifies the attribute of the model to be initialized, "
- "and exclude_keys excludes keys of model states for the initialization."
- "e.g.\n"
- " # Load all parameters"
- " --init_param some/where/model.pb\n"
- " # Load only decoder parameters"
- " --init_param some/where/model.pb:decoder:decoder\n"
- " # Load only decoder parameters excluding decoder.embed"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
- )
- parser.add_argument(
- "--ignore_init_mismatch",
- type=str2bool,
- default=False,
- help="Ignore size mismatch when loading pre-trained model",
- )
- parser.add_argument(
- "--freeze_param",
- type=str,
- default=[],
- action="append",
- help="Freeze parameters",
- )
-
- # dataset related
- parser.add_argument(
- "--dataset_type",
- type=str,
- default="small",
- help="whether to use dataloader for large dataset",
- )
- parser.add_argument(
- "--dataset_conf",
- action=NestedDictAction,
- default=dict(),
- help=f"The keyword arguments for dataset",
- )
- parser.add_argument(
- "--data_dir",
- type=str,
- default=None,
- help="root path of data",
- )
- parser.add_argument(
- "--train_set",
- type=str,
- default="train",
- help="train dataset",
- )
- parser.add_argument(
- "--valid_set",
- type=str,
- default="validation",
- help="dev dataset",
- )
- parser.add_argument(
- "--data_file_names",
- type=str,
- default="wav.scp,text",
- help="input data files",
- )
- parser.add_argument(
- "--speed_perturb",
- type=float,
- nargs="+",
- default=None,
- help="speed perturb",
- )
- parser.add_argument(
- "--use_preprocessor",
- type=str2bool,
- default=True,
- help="Apply preprocessing to data or not",
- )
-
- # optimization related
- parser.add_argument(
- "--optim",
- type=lambda x: x.lower(),
- default="adam",
- help="The optimizer type",
- )
- parser.add_argument(
- "--optim_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for optimizer",
- )
- parser.add_argument(
- "--scheduler",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The lr scheduler type",
- )
- parser.add_argument(
- "--scheduler_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for lr scheduler",
- )
-
- # most task related
- parser.add_argument(
- "--init",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The initialization method",
- choices=[
- "chainer",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- None,
- ],
- )
- parser.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="A text mapping int-id to token",
- )
- parser.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word"],
- help="",
- )
- parser.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model file fo sentencepiece",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Apply text cleaning",
- )
- parser.add_argument(
- "--g2p",
- type=str_or_none,
- choices=g2p_choices,
- default=None,
- help="Specify g2p method if --token_type=phn",
- )
-
- # pai related
- parser.add_argument(
- "--use_pai",
- type=str2bool,
- default=False,
- help="flag to indicate whether training on PAI",
- )
- parser.add_argument(
- "--simple_ddp",
- type=str2bool,
- default=False,
- )
- parser.add_argument(
- "--num_worker_count",
- type=int,
- default=1,
- help="The number of machines on PAI.",
- )
- parser.add_argument(
- "--access_key_id",
- type=str,
- default=None,
- help="The username for oss.",
- )
- parser.add_argument(
- "--access_key_secret",
- type=str,
- default=None,
- help="The password for oss.",
- )
- parser.add_argument(
- "--endpoint",
- type=str,
- default=None,
- help="The endpoint for oss.",
- )
- parser.add_argument(
- "--bucket_name",
- type=str,
- default=None,
- help="The bucket name for oss.",
- )
- parser.add_argument(
- "--oss_bucket",
- default=None,
- help="oss bucket.",
- )
- parser.add_argument(
- "--enable_lora",
- type=str2bool,
- default=False,
- help="Apply lora for finetuning.",
- )
- parser.add_argument(
- "--lora_bias",
- type=str,
- default="none",
- help="lora bias.",
- )
-
- return parser
-
-
-if __name__ == '__main__':
- parser = get_parser()
- args, extra_task_params = parser.parse_known_args()
- if extra_task_params:
- args = build_args(args, parser, extra_task_params)
-
- # set random seed
- set_all_random_seed(args.seed)
- torch.backends.cudnn.enabled = args.cudnn_enabled
- torch.backends.cudnn.benchmark = args.cudnn_benchmark
- torch.backends.cudnn.deterministic = args.cudnn_deterministic
-
- # ddp init
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
- args.distributed = args.ngpu > 1 or args.dist_world_size > 1
- distributed_option = build_distributed(args)
-
- # for logging
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- else:
- logging.basicConfig(
- level="ERROR",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- # prepare files for dataloader
- prepare_data(args, distributed_option)
-
- model = build_model(args)
- model = model.to(
- dtype=getattr(torch, args.train_dtype),
- device="cuda" if args.ngpu > 0 else "cpu",
- )
- if args.enable_lora:
- mark_only_lora_as_trainable(model, args.lora_bias)
- for t in args.freeze_param:
- for k, p in model.named_parameters():
- if k.startswith(t + ".") or k == t:
- logging.info(f"Setting {k}.requires_grad = False")
- p.requires_grad = False
-
- optimizers = build_optimizer(args, model=model)
- schedulers = build_scheduler(args, optimizers)
-
- logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
- distributed_option.dist_rank,
- distributed_option.local_rank))
- logging.info(pytorch_cudnn_version())
- logging.info("Args: {}".format(args))
- logging.info(model_summary(model))
- logging.info("Optimizer: {}".format(optimizers))
- logging.info("Scheduler: {}".format(schedulers))
-
- # dump args to config.yaml
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- os.makedirs(args.output_dir, exist_ok=True)
- with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
- logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
- if args.use_pai:
- buffer = BytesIO()
- torch.save({"config": vars(args)}, buffer)
- args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
- else:
- yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
-
- for p in args.init_param:
- logging.info(f"Loading pretrained params from {p}")
- load_pretrained_model(
- model=model,
- init_param=p,
- ignore_init_mismatch=args.ignore_init_mismatch,
- map_location=f"cuda:{torch.cuda.current_device()}"
- if args.ngpu > 0
- else "cpu",
- oss_bucket=args.oss_bucket,
- )
-
- # dataloader for training/validation
- train_dataloader, valid_dataloader = build_dataloader(args)
-
- # Trainer, including model, optimizers, etc.
- trainer = build_trainer(
- args=args,
- model=model,
- optimizers=optimizers,
- schedulers=schedulers,
- train_dataloader=train_dataloader,
- valid_dataloader=valid_dataloader,
- distributed_option=distributed_option
- )
-
- trainer.run()
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
deleted file mode 100644
index 08018a7..0000000
--- a/funasr/build_utils/build_args.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from funasr.models.ctc import CTC
-from funasr.utils import config_argparse
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-
-def build_args(args, parser, extra_task_params):
- task_parser = config_argparse.ArgumentParser("Task related config")
- if args.task_name == "asr":
- from funasr.build_utils.build_asr_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
- task_parser.add_argument(
- "--split_with_space",
- type=str2bool,
- default=True,
- help="whether to split text using <space>",
- )
- task_parser.add_argument(
- "--seg_dict_file",
- type=str,
- default=None,
- help="seg_dict_file for text processing",
- )
- task_parser.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of input dimension of the feature",
- )
- task_parser.add_argument(
- "--ctc_conf",
- action=NestedDictAction,
- default=get_default_kwargs(CTC),
- help="The keyword arguments for CTC class.",
- )
- task_parser.add_argument(
- "--cmvn_file",
- type=str_or_none,
- default=None,
- help="The path of cmvn file.",
- )
-
- elif args.task_name == "pretrain":
- from funasr.build_utils.build_pretrain_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
- task_parser.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of input dimension of the feature",
- )
- task_parser.add_argument(
- "--cmvn_file",
- type=str_or_none,
- default=None,
- help="The path of cmvn file.",
- )
-
- elif args.task_name == "lm":
- from funasr.build_utils.build_lm_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
-
- elif args.task_name == "punc":
- from funasr.build_utils.build_punc_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
-
- elif args.task_name == "vad":
- from funasr.build_utils.build_vad_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
- task_parser.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of input dimension of the feature",
- )
- task_parser.add_argument(
- "--cmvn_file",
- type=str_or_none,
- default=None,
- help="The path of cmvn file.",
- )
-
- elif args.task_name == "diar":
- from funasr.build_utils.build_diar_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
- task_parser.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of input dimension of the feature",
- )
-
- elif args.task_name == "sv":
- from funasr.build_utils.build_sv_model import class_choices_list
- for class_choices in class_choices_list:
- class_choices.add_arguments(task_parser)
- task_parser.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of input dimension of the feature",
- )
-
- else:
- raise NotImplementedError("Not supported task: {}".format(args.task_name))
-
- for action in parser._actions:
- if not any(action.dest == a.dest for a in task_parser._actions):
- task_parser._add_action(action)
-
- task_parser.set_defaults(**vars(args))
- task_args = task_parser.parse_args(extra_task_params)
- return task_args
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
deleted file mode 100644
index fd47bd3..0000000
--- a/funasr/build_utils/build_asr_model.py
+++ /dev/null
@@ -1,559 +0,0 @@
-import logging
-
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
-from funasr.models.e2e_asr import ASRModel
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.models.e2e_asr_mfcca import MFCCA
-
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
-from funasr.models.e2e_asr_bat import BATModel
-
-from funasr.models.e2e_sa_asr import SAASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
-
-from funasr.models.e2e_tp import TimestampPredictor
-from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
-from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.encoder.rwkv_encoder import RWKVEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.default import MultiChannelFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.modules.subsampling import Conv1dSubsampling
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- multichannelfrontend=MultiChannelFrontend,
- ),
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(
- specaug=SpecAug,
- specaug_lfr=SpecAugLFR,
- ),
- default=None,
- optional=True,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- default=None,
- optional=True,
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- asr=ASRModel,
- uniasr=UniASR,
- paraformer=Paraformer,
- paraformer_online=ParaformerOnline,
- paraformer_bert=ParaformerBert,
- bicif_paraformer=BiCifParaformer,
- contextual_paraformer=ContextualParaformer,
- neatcontextual_paraformer=NeatContextualParaformer,
- mfcca=MFCCA,
- timestamp_prediction=TimestampPredictor,
- rnnt=TransducerModel,
- rnnt_unified=UnifiedTransducerModel,
- sa_asr=SAASRModel,
- bat=BATModel,
- ),
- default="asr",
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- branchformer=BranchformerEncoder,
- e_branchformer=EBranchformerEncoder,
- mfcca_enc=MFCCAEncoder,
- chunk_conformer=ConformerChunkEncoder,
- rwkv=RWKVEncoder,
- ),
- default="rnn",
-)
-asr_encoder_choices = ClassChoices(
- "asr_encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- mfcca_enc=MFCCAEncoder,
- ),
- default="rnn",
-)
-
-spk_encoder_choices = ClassChoices(
- "spk_encoder",
- classes=dict(
- resnet34_diar=ResNet34Diar,
- ),
- default="resnet34_diar",
-)
-encoder_choices2 = ClassChoices(
- "encoder2",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- ),
- default="rnn",
-)
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- rnn=RNNDecoder,
- fsmn_scama_opt=FsmnDecoderSCAMAOpt,
- paraformer_decoder_sanm=ParaformerSANMDecoder,
- paraformer_decoder_san=ParaformerDecoderSAN,
- contextual_paraformer_decoder=ContextualParaformerDecoder,
- sa_decoder=SAAsrTransformerDecoder,
- ),
- default="rnn",
-)
-decoder_choices2 = ClassChoices(
- "decoder2",
- classes=dict(
- transformer=TransformerDecoder,
- lightweight_conv=LightweightConvolutionTransformerDecoder,
- lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
- dynamic_conv=DynamicConvolutionTransformerDecoder,
- dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
- rnn=RNNDecoder,
- fsmn_scama_opt=FsmnDecoderSCAMAOpt,
- paraformer_decoder_sanm=ParaformerSANMDecoder,
- ),
- type_check=AbsDecoder,
- default="rnn",
-)
-predictor_choices = ClassChoices(
- name="predictor",
- classes=dict(
- cif_predictor=CifPredictor,
- ctc_predictor=None,
- cif_predictor_v2=CifPredictorV2,
- cif_predictor_v3=CifPredictorV3,
- bat_predictor=BATPredictor,
- ),
- default="cif_predictor",
- optional=True,
-)
-predictor_choices2 = ClassChoices(
- name="predictor2",
- classes=dict(
- cif_predictor=CifPredictor,
- ctc_predictor=None,
- cif_predictor_v2=CifPredictorV2,
- ),
- default="cif_predictor",
- optional=True,
-)
-stride_conv_choices = ClassChoices(
- name="stride_conv",
- classes=dict(
- stride_conv1d=Conv1dSubsampling
- ),
- default="stride_conv1d",
- optional=True,
-)
-rnnt_decoder_choices = ClassChoices(
- name="rnnt_decoder",
- classes=dict(
- rnnt=RNNTDecoder,
- ),
- default="rnnt",
- optional=True,
-)
-joint_network_choices = ClassChoices(
- name="joint_network",
- classes=dict(
- joint_network=JointNetwork,
- ),
- default="joint_network",
- optional=True,
-)
-
-class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --model and --model_conf
- model_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --decoder and --decoder_conf
- decoder_choices,
- # --predictor and --predictor_conf
- predictor_choices,
- # --encoder2 and --encoder2_conf
- encoder_choices2,
- # --decoder2 and --decoder2_conf
- decoder_choices2,
- # --predictor2 and --predictor2_conf
- predictor_choices2,
- # --stride_conv and --stride_conv_conf
- stride_conv_choices,
- # --rnnt_decoder and --rnnt_decoder_conf
- rnnt_decoder_choices,
- # --joint_network and --joint_network_conf
- joint_network_choices,
- # --asr_encoder and --asr_encoder_conf
- asr_encoder_choices,
- # --spk_encoder and --spk_encoder_conf
- spk_encoder_choices,
-]
-
-
-def build_asr_model(args):
- # token_list
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
- args.token_list = list(token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
- elif isinstance(args.token_list, (tuple, list)):
- token_list = list(args.token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
- else:
- token_list = None
- vocab_size = None
-
- # frontend
- if hasattr(args, "input_size") and args.input_size is None:
- frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
- frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
- else:
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size if hasattr(args, "input_size") else None
-
- # data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- if args.model == "mfcca":
- normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
- else:
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(input_size=input_size, **args.encoder_conf)
-
- # decoder
- if hasattr(args, "decoder") and args.decoder is not None:
- decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder.output_size(),
- **args.decoder_conf,
- )
- else:
- decoder = None
-
- # ctc
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
- )
-
- if args.model in ["asr", "mfcca"]:
- model_class = model_choices.get_class(args.model)
- model = model_class(
- vocab_size=vocab_size,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- ctc=ctc,
- token_list=token_list,
- **args.model_conf,
- )
- elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
- "contextual_paraformer", "neatcontextual_paraformer"]:
- # predictor
- predictor_class = predictor_choices.get_class(args.predictor)
- predictor = predictor_class(**args.predictor_conf)
-
- model_class = model_choices.get_class(args.model)
- model = model_class(
- vocab_size=vocab_size,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- ctc=ctc,
- token_list=token_list,
- predictor=predictor,
- **args.model_conf,
- )
- elif args.model == "uniasr":
- # stride_conv
- stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
- stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
- odim=input_size + encoder.output_size())
- stride_conv_output_size = stride_conv.output_size()
-
- # encoder2
- encoder_class2 = encoder_choices2.get_class(args.encoder2)
- encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
-
- # decoder2
- decoder_class2 = decoder_choices2.get_class(args.decoder2)
- decoder2 = decoder_class2(
- vocab_size=vocab_size,
- encoder_output_size=encoder2.output_size(),
- **args.decoder2_conf,
- )
-
- # ctc2
- ctc2 = CTC(
- odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
- )
-
- # predictor
- predictor_class = predictor_choices.get_class(args.predictor)
- predictor = predictor_class(**args.predictor_conf)
-
- # predictor2
- predictor_class = predictor_choices2.get_class(args.predictor2)
- predictor2 = predictor_class(**args.predictor2_conf)
-
- model_class = model_choices.get_class(args.model)
- model = model_class(
- vocab_size=vocab_size,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- ctc=ctc,
- token_list=token_list,
- predictor=predictor,
- ctc2=ctc2,
- encoder2=encoder2,
- decoder2=decoder2,
- predictor2=predictor2,
- stride_conv=stride_conv,
- **args.model_conf,
- )
- elif args.model == "timestamp_prediction":
- # predictor
- predictor_class = predictor_choices.get_class(args.predictor)
- predictor = predictor_class(**args.predictor_conf)
-
- model_class = model_choices.get_class(args.model)
- model = model_class(
- frontend=frontend,
- encoder=encoder,
- predictor=predictor,
- token_list=token_list,
- **args.model_conf,
- )
- elif args.model == "rnnt" or args.model == "rnnt_unified":
- # 5. Decoder
- encoder_output_size = encoder.output_size()
-
- rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
- decoder = rnnt_decoder_class(
- vocab_size,
- **args.rnnt_decoder_conf,
- )
- decoder_output_size = decoder.output_size
-
- if getattr(args, "decoder", None) is not None:
- att_decoder_class = decoder_choices.get_class(args.decoder)
-
- att_decoder = att_decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **args.decoder_conf,
- )
- else:
- att_decoder = None
- # 6. Joint Network
- joint_network = JointNetwork(
- vocab_size,
- encoder_output_size,
- decoder_output_size,
- **args.joint_network_conf,
- )
-
- model_class = model_choices.get_class(args.model)
- # 7. Build model
- model = model_class(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
- elif args.model == "bat":
- # 5. Decoder
- encoder_output_size = encoder.output_size()
-
- rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
- decoder = rnnt_decoder_class(
- vocab_size,
- **args.rnnt_decoder_conf,
- )
- decoder_output_size = decoder.output_size
-
- if getattr(args, "decoder", None) is not None:
- att_decoder_class = decoder_choices.get_class(args.decoder)
-
- att_decoder = att_decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **args.decoder_conf,
- )
- else:
- att_decoder = None
- # 6. Joint Network
- joint_network = JointNetwork(
- vocab_size,
- encoder_output_size,
- decoder_output_size,
- **args.joint_network_conf,
- )
-
- predictor_class = predictor_choices.get_class(args.predictor)
- predictor = predictor_class(**args.predictor_conf)
-
- model_class = model_choices.get_class(args.model)
- # 7. Build model
- model = model_class(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- predictor=predictor,
- **args.model_conf,
- )
- elif args.model == "sa_asr":
- asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
- asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
- spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
- spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=asr_encoder.output_size(),
- **args.decoder_conf,
- )
- ctc = CTC(
- odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
- )
-
- model_class = model_choices.get_class(args.model)
- model = model_class(
- vocab_size=vocab_size,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- asr_encoder=asr_encoder,
- spk_encoder=spk_encoder,
- decoder=decoder,
- ctc=ctc,
- token_list=token_list,
- **args.model_conf,
- )
-
- else:
- raise NotImplementedError("Not supported model: {}".format(args.model))
-
- # initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
deleted file mode 100644
index 473097e..0000000
--- a/funasr/build_utils/build_dataloader.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
-from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
-
-
-def build_dataloader(args):
- if args.dataset_type == "small":
- if args.task_name == "diar" and args.model == "eend_ola":
- from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
- train_iter_factory = EENDOLADataLoader(
- data_file=args.train_data_path_and_name_and_type[0][0],
- batch_size=args.dataset_conf["batch_conf"]["batch_size"],
- num_workers=args.dataset_conf["num_workers"],
- shuffle=True)
- valid_iter_factory = EENDOLADataLoader(
- data_file=args.valid_data_path_and_name_and_type[0][0],
- batch_size=args.dataset_conf["batch_conf"]["batch_size"],
- num_workers=0,
- shuffle=False)
- else:
- train_iter_factory = SequenceIterFactory(args, mode="train")
- valid_iter_factory = SequenceIterFactory(args, mode="valid")
- elif args.dataset_type == "large":
- train_iter_factory = LargeDataLoader(args, mode="train")
- valid_iter_factory = LargeDataLoader(args, mode="valid")
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-
- return train_iter_factory, valid_iter_factory
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
deleted file mode 100644
index 1be04c7..0000000
--- a/funasr/build_utils/build_diar_model.py
+++ /dev/null
@@ -1,326 +0,0 @@
-import logging
-
-import torch
-
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
-from funasr.models.e2e_diar_sond import DiarSondModel
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
-from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
-from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
-from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.specaug.abs_profileaug import AbsProfileAug
-from funasr.models.specaug.profileaug import ProfileAug
-from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
-from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- wav_frontend_mel23=WavFrontendMel23,
- ),
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(
- specaug=SpecAug,
- specaug_lfr=SpecAugLFR,
- ),
- default=None,
- optional=True,
-)
-profileaug_choices = ClassChoices(
- name="profileaug",
- classes=dict(
- profileaug=ProfileAug,
- ),
- type_check=AbsProfileAug,
- default=None,
- optional=True,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- default=None,
- optional=True,
-)
-label_aggregator_choices = ClassChoices(
- "label_aggregator",
- classes=dict(
- label_aggregator=LabelAggregate,
- label_aggregator_max_pool=LabelAggregateMaxPooling,
- ),
- default=None,
- optional=True,
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- sond=DiarSondModel,
- eend_ola=DiarEENDOLAModel,
- ),
- default="sond",
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- san=SelfAttentionEncoder,
- fsmn=FsmnEncoder,
- conv=ConvEncoder,
- resnet34=ResNet34Diar,
- resnet34_sp_l2reg=ResNet34SpL2RegDiar,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- ecapa_tdnn=ECAPA_TDNN,
- eend_ola_transformer=EENDOLATransformerEncoder,
- ),
- default="resnet34",
-)
-speaker_encoder_choices = ClassChoices(
- "speaker_encoder",
- classes=dict(
- conformer=ConformerEncoder,
- transformer=TransformerEncoder,
- rnn=RNNEncoder,
- sanm=SANMEncoder,
- san=SelfAttentionEncoder,
- fsmn=FsmnEncoder,
- conv=ConvEncoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
- data2vec_encoder=Data2VecEncoder,
- ),
- default=None,
- optional=True
-)
-cd_scorer_choices = ClassChoices(
- "cd_scorer",
- classes=dict(
- san=SelfAttentionEncoder,
- ),
- default=None,
- optional=True,
-)
-ci_scorer_choices = ClassChoices(
- "ci_scorer",
- classes=dict(
- dot=DotScorer,
- cosine=CosScorer,
- conv=ConvEncoder,
- ),
- type_check=torch.nn.Module,
- default=None,
- optional=True,
-)
-# decoder is used for output (e.g. post_net in SOND)
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- rnn=RNNEncoder,
- fsmn=FsmnEncoder,
- ),
- type_check=torch.nn.Module,
- default="fsmn",
-)
-# encoder_decoder_attractor is used for EEND-OLA
-encoder_decoder_attractor_choices = ClassChoices(
- "encoder_decoder_attractor",
- classes=dict(
- eda=EncoderDecoderAttractor,
- ),
- type_check=torch.nn.Module,
- default="eda",
-)
-class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --profileaug and --profileaug_conf
- profileaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --label_aggregator and --label_aggregator_conf
- label_aggregator_choices,
- # --model and --model_conf
- model_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --speaker_encoder and --speaker_encoder_conf
- speaker_encoder_choices,
- # --cd_scorer and cd_scorer_conf
- cd_scorer_choices,
- # --ci_scorer and ci_scorer_conf
- ci_scorer_choices,
- # --decoder and --decoder_conf
- decoder_choices,
- # --eda and --eda_conf
- encoder_decoder_attractor_choices,
-]
-
-
-def build_diar_model(args):
- # token_list
- if args.token_list is not None:
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
-
- # Overwriting token_list to keep it as "portable".
- args.token_list = list(token_list)
- elif isinstance(args.token_list, (tuple, list)):
- token_list = list(args.token_list)
- else:
- raise RuntimeError("token_list must be str or list")
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
- else:
- token_list = None
- vocab_size = None
-
- # frontend
- if args.input_size is None:
- frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend':
- frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
- else:
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size
-
- if args.model == "sond":
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(input_size=input_size ,**args.encoder_conf)
-
- # data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # Data augmentation for Profiles
- if hasattr(args, "profileaug") and args.profileaug is not None:
- profileaug_class = profileaug_choices.get_class(args.profileaug)
- profileaug = profileaug_class(**args.profileaug_conf)
- else:
- profileaug = None
-
- # normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # speaker encoder
- if getattr(args, "speaker_encoder", None) is not None:
- speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
- speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
- else:
- speaker_encoder = None
-
- # ci scorer
- if getattr(args, "ci_scorer", None) is not None:
- ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
- ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
- else:
- ci_scorer = None
-
- # cd scorer
- if getattr(args, "cd_scorer", None) is not None:
- cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
- cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
- else:
- cd_scorer = None
-
- # decoder
- decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(**args.decoder_conf)
-
- # logger aggregator
- if getattr(args, "label_aggregator", None) is not None:
- label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
- label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
- else:
- label_aggregator = None
-
- model_class = model_choices.get_class(args.model)
- model = model_class(
- vocab_size=vocab_size,
- frontend=frontend,
- specaug=specaug,
- profileaug=profileaug,
- normalize=normalize,
- label_aggregator=label_aggregator,
- encoder=encoder,
- speaker_encoder=speaker_encoder,
- ci_scorer=ci_scorer,
- cd_scorer=cd_scorer,
- decoder=decoder,
- token_list=token_list,
- **args.model_conf,
- )
-
- elif args.model == "eend_ola":
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(**args.encoder_conf)
-
- # encoder-decoder attractor
- encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
- encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
-
- # 9. Build model
- model_class = model_choices.get_class(args.model)
- model = model_class(
- frontend=frontend,
- encoder=encoder,
- encoder_decoder_attractor=encoder_decoder_attractor,
- **args.model_conf,
- )
-
- else:
- raise NotImplementedError("Not supported model: {}".format(args.model))
-
- # 10. Initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_distributed.py b/funasr/build_utils/build_distributed.py
deleted file mode 100644
index b64b4c0..0000000
--- a/funasr/build_utils/build_distributed.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import logging
-import os
-
-import torch
-
-from funasr.train.distributed_utils import DistributedOption
-from funasr.utils.build_dataclass import build_dataclass
-
-
-def build_distributed(args):
- distributed_option = build_dataclass(DistributedOption, args)
- if args.use_pai:
- distributed_option.init_options_pai()
- distributed_option.init_torch_distributed_pai(args)
- elif not args.simple_ddp:
- distributed_option.init_torch_distributed(args)
- elif args.distributed and args.simple_ddp:
- distributed_option.init_torch_distributed_pai(args)
- args.ngpu = torch.distributed.get_world_size()
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- else:
- logging.basicConfig(
- level="ERROR",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
- distributed_option.dist_rank,
- distributed_option.local_rank))
- return distributed_option
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
deleted file mode 100644
index f78a20e..0000000
--- a/funasr/build_utils/build_lm_model.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import logging
-
-from funasr.train.abs_model import AbsLM
-from funasr.train.abs_model import LanguageModel
-from funasr.models.seq_rnn_lm import SequentialRNNLM
-from funasr.models.transformer_lm import TransformerLM
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-lm_choices = ClassChoices(
- "lm",
- classes=dict(
- seq_rnn=SequentialRNNLM,
- transformer=TransformerLM,
- ),
- type_check=AbsLM,
- default="seq_rnn",
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- lm=LanguageModel,
- ),
- default="lm",
-)
-
-class_choices_list = [
- # --lm and --lm_conf
- lm_choices,
- # --model and --model_conf
- model_choices
-]
-
-
-def build_lm_model(args):
- # token_list
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
- args.token_list = list(token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
- elif isinstance(args.token_list, (tuple, list)):
- token_list = list(args.token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
- else:
- vocab_size = None
-
- # lm
- lm_class = lm_choices.get_class(args.lm)
- lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
-
- args.model = args.model if hasattr(args, "model") else "lm"
- model_class = model_choices.get_class(args.model)
- model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
-
- # initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
deleted file mode 100644
index 66fdfd0..0000000
--- a/funasr/build_utils/build_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from funasr.build_utils.build_asr_model import build_asr_model
-from funasr.build_utils.build_diar_model import build_diar_model
-from funasr.build_utils.build_lm_model import build_lm_model
-from funasr.build_utils.build_pretrain_model import build_pretrain_model
-from funasr.build_utils.build_punc_model import build_punc_model
-from funasr.build_utils.build_sv_model import build_sv_model
-from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_ss_model import build_ss_model
-
-
-def build_model(args):
- if args.task_name == "asr":
- model = build_asr_model(args)
- elif args.task_name == "pretrain":
- model = build_pretrain_model(args)
- elif args.task_name == "lm":
- model = build_lm_model(args)
- elif args.task_name == "punc":
- model = build_punc_model(args)
- elif args.task_name == "vad":
- model = build_vad_model(args)
- elif args.task_name == "diar":
- model = build_diar_model(args)
- elif args.task_name == "sv":
- model = build_sv_model(args)
- elif args.task_name == "ss":
- model = build_ss_model(args)
- else:
- raise NotImplementedError("Not supported task: {}".format(args.task_name))
-
- return model
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
deleted file mode 100644
index 65e0d5f..0000000
--- a/funasr/build_utils/build_model_from_file.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import argparse
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import torch
-import yaml
-
-from funasr.build_utils.build_model import build_model
-from funasr.models.base_model import FunASRModel
-
-
-def build_model_from_file(
- config_file: Union[Path, str] = None,
- model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- task_name: str = "asr",
- mode: str = "paraformer",
-):
- """Build model from the files.
-
- This method is used for inference or fine-tuning.
-
- Args:
- config_file: The yaml file saved when training.
- model_file: The model file saved when training.
- device: Device type, "cpu", "cuda", or "cuda:N".
-
- """
- if config_file is None:
- assert model_file is not None, (
- "The argument 'model_file' must be provided "
- "if the argument 'config_file' is not specified."
- )
- config_file = Path(model_file).parent / "config.yaml"
- else:
- config_file = Path(config_file)
-
- with config_file.open("r", encoding="utf-8") as f:
- args = yaml.safe_load(f)
- if cmvn_file is not None:
- args["cmvn_file"] = cmvn_file
- args = argparse.Namespace(**args)
- args.task_name = task_name
- model = build_model(args)
- if not isinstance(model, FunASRModel):
- raise RuntimeError(
- f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
- )
- model.to(device)
- model_dict = dict()
- model_name_pth = None
- if model_file is not None:
- logging.info("model_file is {}".format(model_file))
- if device == "cuda":
- device = f"cuda:{torch.cuda.current_device()}"
- model_dir = os.path.dirname(model_file)
- model_name = os.path.basename(model_file)
- if "model.ckpt-" in model_name or ".bin" in model_name:
- model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
- '.pb')) if ".bin" in model_name else os.path.join(
- model_dir, "{}.pb".format(model_name))
- if os.path.exists(model_name_pth):
- logging.info("model_file is load from pth: {}".format(model_name_pth))
- model_dict = torch.load(model_name_pth, map_location=device)
- else:
- model_dict = convert_tf2torch(model, model_file, mode)
- model.load_state_dict(model_dict)
- else:
- model_dict = torch.load(model_file, map_location=device)
- if task_name == "ss":
- model_dict = model_dict['model']
- if task_name == "diar" and mode == "sond":
- model_dict = fileter_model_dict(model_dict, model.state_dict())
- if task_name == "vad":
- model.encoder.load_state_dict(model_dict)
- else:
- model.load_state_dict(model_dict)
- if model_name_pth is not None and not os.path.exists(model_name_pth):
- torch.save(model_dict, model_name_pth)
- logging.info("model_file is saved to pth: {}".format(model_name_pth))
-
- return model, args
-
-
-def convert_tf2torch(
- model,
- ckpt,
- mode,
-):
- assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
- logging.info("start convert tf model to torch model")
- from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
- var_dict_tf = load_tf_dict(ckpt)
- var_dict_torch = model.state_dict()
- var_dict_torch_update = dict()
- if mode == "uniasr":
- # encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # predictor
- var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # encoder2
- var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # predictor2
- var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder2
- var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # stride_conv
- var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- elif mode == "paraformer":
- # encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # predictor
- var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # bias_encoder
- var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- elif "mode" == "sond":
- if model.encoder is not None:
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # speaker encoder
- if model.speaker_encoder is not None:
- var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # cd scorer
- if model.cd_scorer is not None:
- var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # ci scorer
- if model.ci_scorer is not None:
- var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder
- if model.decoder is not None:
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- elif "mode" == "sv":
- # speech encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # pooling layer
- var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- else:
- # encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # predictor
- var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- # bias_encoder
- var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
- return var_dict_torch_update
-
- return var_dict_torch_update
-
-
-def fileter_model_dict(src_dict: dict, dest_dict: dict):
- from collections import OrderedDict
- new_dict = OrderedDict()
- for key, value in src_dict.items():
- if key in dest_dict:
- new_dict[key] = value
- else:
- logging.info("{} is no longer needed in this model.".format(key))
- for key, value in dest_dict.items():
- if key not in new_dict:
- logging.warning("{} is missed in checkpoint.".format(key))
- return new_dict
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
deleted file mode 100644
index bd0b73d..0000000
--- a/funasr/build_utils/build_optimizer.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import torch
-
-from funasr.optimizers.fairseq_adam import FairseqAdam
-from funasr.optimizers.sgd import SGD
-
-
-def build_optimizer(args, model):
- optim_classes = dict(
- adam=torch.optim.Adam,
- fairseq_adam=FairseqAdam,
- adamw=torch.optim.AdamW,
- sgd=SGD,
- adadelta=torch.optim.Adadelta,
- adagrad=torch.optim.Adagrad,
- adamax=torch.optim.Adamax,
- asgd=torch.optim.ASGD,
- lbfgs=torch.optim.LBFGS,
- rmsprop=torch.optim.RMSprop,
- rprop=torch.optim.Rprop,
- )
-
- optim_class = optim_classes.get(args.optim)
- if optim_class is None:
- raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
- optimizer = optim_class(model.parameters(), **args.optim_conf)
-
- optimizers = [optimizer]
- return optimizers
\ No newline at end of file
diff --git a/funasr/build_utils/build_pretrain_model.py b/funasr/build_utils/build_pretrain_model.py
deleted file mode 100644
index 0784fb2..0000000
--- a/funasr/build_utils/build_pretrain_model.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.data2vec import Data2VecPretrainModel
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.specaug.specaug import SpecAug
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- wav_frontend=WavFrontend,
- ),
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(specaug=SpecAug),
- default=None,
- optional=True,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- default=None,
- optional=True,
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- data2vec_encoder=Data2VecEncoder,
- ),
- default="data2vec_encoder",
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- data2vec=Data2VecPretrainModel,
- ),
- default="data2vec",
-)
-class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --model and --model_conf
- model_choices,
-]
-
-
-def build_pretrain_model(args):
- # frontend
- if args.input_size is None:
- frontend_class = frontend_choices.get_class(args.frontend)
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size
-
- # data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(
- input_size=input_size,
- **args.encoder_conf,
- )
-
- if args.model == "data2vec":
- model_class = model_choices.get_class("data2vec")
- model = model_class(
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- )
- else:
- raise NotImplementedError("Not supported model: {}".format(args.model))
-
- # initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
deleted file mode 100644
index 62ccaf2..0000000
--- a/funasr/build_utils/build_punc_model.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import logging
-
-from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
-from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_model import PunctuationModel
-from funasr.train.class_choices import ClassChoices
-
-punc_choices = ClassChoices(
- "punctuation",
- classes=dict(
- target_delay=TargetDelayTransformer,
- vad_realtime=VadRealtimeTransformer
- ),
- default="target_delay",
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- punc=PunctuationModel,
- ),
- default="punc",
-)
-class_choices_list = [
- # --punc and --punc_conf
- punc_choices,
- # --model and --model_conf
- model_choices
-]
-
-
-def build_punc_model(args):
- # token_list and punc list
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
- args.token_list = token_list.copy()
- if isinstance(args.punc_list, str):
- with open(args.punc_list, encoding="utf-8") as f2:
- pairs = [line.rstrip().split(":") for line in f2]
- punc_list = [pair[0] for pair in pairs]
- punc_weight_list = [float(pair[1]) for pair in pairs]
- args.punc_list = punc_list.copy()
- elif isinstance(args.punc_list, list):
- punc_list = args.punc_list.copy()
- punc_weight_list = [1] * len(punc_list)
- if isinstance(args.token_list, (tuple, list)):
- token_list = args.token_list.copy()
- else:
- raise RuntimeError("token_list must be str or dict")
-
- vocab_size = len(token_list)
- punc_size = len(punc_list)
- logging.info(f"Vocabulary size: {vocab_size}")
-
- # punc
- punc_class = punc_choices.get_class(args.punctuation)
- punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
-
- if "punc_weight" in args.model_conf:
- args.model_conf.pop("punc_weight")
- model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
-
- # initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
deleted file mode 100644
index 4b9990e..0000000
--- a/funasr/build_utils/build_scheduler.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-import torch.multiprocessing
-import torch.nn
-import torch.optim
-
-from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.tri_stage_scheduler import TriStageLR
-from funasr.schedulers.warmup_lr import WarmupLR
-
-
-def build_scheduler(args, optimizers):
- scheduler_classes = dict(
- ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
- lambdalr=torch.optim.lr_scheduler.LambdaLR,
- steplr=torch.optim.lr_scheduler.StepLR,
- multisteplr=torch.optim.lr_scheduler.MultiStepLR,
- exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
- CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
- noamlr=NoamLR,
- warmuplr=WarmupLR,
- tri_stage=TriStageLR,
- cycliclr=torch.optim.lr_scheduler.CyclicLR,
- onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
- CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
- )
-
- schedulers = []
- for i, optim in enumerate(optimizers, 1):
- suf = "" if i == 1 else str(i)
- name = getattr(args, f"scheduler{suf}")
- conf = getattr(args, f"scheduler{suf}_conf")
- if name is not None:
- cls_ = scheduler_classes.get(name)
- if cls_ is None:
- raise ValueError(
- f"must be one of {list(scheduler_classes)}: {name}"
- )
- scheduler = cls_(optim, **conf)
- else:
- scheduler = None
-
- schedulers.append(scheduler)
-
- return schedulers
\ No newline at end of file
diff --git a/funasr/build_utils/build_ss_model.py b/funasr/build_utils/build_ss_model.py
deleted file mode 100644
index a6b5209..0000000
--- a/funasr/build_utils/build_ss_model.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from funasr.models.e2e_ss import MossFormer
-
-def build_ss_model(args):
- model = MossFormer(
- in_channels=args.encoder_embedding_dim,
- out_channels=args.mossformer_sequence_dim,
- num_blocks=args.num_mossformer_layer,
- kernel_size=args.encoder_kernel_size,
- norm=args.norm,
- num_spks=args.num_spks,
- skip_around_intra=args.skip_around_intra,
- use_global_pos_enc=args.use_global_pos_enc,
- max_length=args.max_length)
-
- return model
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
deleted file mode 100644
index 02fc263..0000000
--- a/funasr/build_utils/build_streaming_iterator.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import numpy as np
-from torch.utils.data import DataLoader
-
-from funasr.datasets.iterable_dataset import IterableESPnetDataset
-from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
-from funasr.datasets.small_datasets.preprocessor import build_preprocess
-
-
-def build_streaming_iterator(
- task_name,
- preprocess_args,
- data_path_and_name_and_type,
- key_file: str = None,
- batch_size: int = 1,
- fs: dict = None,
- mc: bool = False,
- dtype: str = np.float32,
- num_workers: int = 1,
- use_collate_fn: bool = True,
- preprocess_fn=None,
- ngpu: int = 0,
- train: bool = False,
-) -> DataLoader:
- """Build DataLoader using iterable dataset"""
-
- # preprocess
- if preprocess_fn is not None:
- preprocess_fn = preprocess_fn
- elif preprocess_args is not None:
- preprocess_args.task_name = task_name
- preprocess_fn = build_preprocess(preprocess_args, train)
- else:
- preprocess_fn = None
-
- # collate
- if not use_collate_fn:
- collate_fn = None
- elif task_name in ["punc", "lm"]:
- collate_fn = CommonCollateFn(int_pad_value=0)
- else:
- collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
- if collate_fn is not None:
- kwargs = dict(collate_fn=collate_fn)
- else:
- kwargs = {}
-
- dataset = IterableESPnetDataset(
- data_path_and_name_and_type,
- float_dtype=dtype,
- fs=fs,
- mc=mc,
- preprocess=preprocess_fn,
- key_file=key_file,
- )
- if dataset.apply_utt2category:
- kwargs.update(batch_size=1)
- else:
- kwargs.update(batch_size=batch_size)
-
- return DataLoader(
- dataset=dataset,
- pin_memory=ngpu > 0,
- num_workers=num_workers,
- **kwargs,
- )
diff --git a/funasr/build_utils/build_sv_model.py b/funasr/build_utils/build_sv_model.py
deleted file mode 100644
index 55df75a..0000000
--- a/funasr/build_utils/build_sv_model.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import logging
-
-import torch
-
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.base_model import FunASRModel
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.sv_decoder import DenseDecoder
-from funasr.models.e2e_sv import ESPnetSVModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.pooling.statistic_pooling import StatisticPooling
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- ),
- type_check=AbsFrontend,
- default="default",
-)
-specaug_choices = ClassChoices(
- name="specaug",
- classes=dict(
- specaug=SpecAug,
- ),
- type_check=AbsSpecAug,
- default=None,
- optional=True,
-)
-normalize_choices = ClassChoices(
- "normalize",
- classes=dict(
- global_mvn=GlobalMVN,
- utterance_mvn=UtteranceMVN,
- ),
- type_check=AbsNormalize,
- default=None,
- optional=True,
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- espnet=ESPnetSVModel,
- ),
- type_check=FunASRModel,
- default="espnet",
-)
-preencoder_choices = ClassChoices(
- name="preencoder",
- classes=dict(
- sinc=LightweightSincConvs,
- linear=LinearProjection,
- ),
- type_check=AbsPreEncoder,
- default=None,
- optional=True,
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- resnet34=ResNet34,
- resnet34_sp_l2reg=ResNet34_SP_L2Reg,
- rnn=RNNEncoder,
- ),
- type_check=AbsEncoder,
- default="resnet34",
-)
-postencoder_choices = ClassChoices(
- name="postencoder",
- classes=dict(
- hugging_face_transformers=HuggingFaceTransformersPostEncoder,
- ),
- type_check=AbsPostEncoder,
- default=None,
- optional=True,
-)
-pooling_choices = ClassChoices(
- name="pooling_type",
- classes=dict(
- statistic=StatisticPooling,
- ),
- type_check=torch.nn.Module,
- default="statistic",
-)
-decoder_choices = ClassChoices(
- "decoder",
- classes=dict(
- dense=DenseDecoder,
- ),
- type_check=AbsDecoder,
- default="dense",
-)
-
-class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --model and --model_conf
- model_choices,
- # --preencoder and --preencoder_conf
- preencoder_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --postencoder and --postencoder_conf
- postencoder_choices,
- # --pooling and --pooling_conf
- pooling_choices,
- # --decoder and --decoder_conf
- decoder_choices,
-]
-
-
-def build_sv_model(args):
- # token_list
- if isinstance(args.token_list, str):
- with open(args.token_list, encoding="utf-8") as f:
- token_list = [line.rstrip() for line in f]
-
- # Overwriting token_list to keep it as "portable".
- args.token_list = list(token_list)
- elif isinstance(args.token_list, (tuple, list)):
- token_list = list(args.token_list)
- else:
- raise RuntimeError("token_list must be str or list")
- vocab_size = len(token_list)
- logging.info(f"Speaker number: {vocab_size}")
-
- # 1. frontend
- if args.input_size is None:
- # Extract features in the model
- frontend_class = frontend_choices.get_class(args.frontend)
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- # Give features from data-loader
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size
-
- # 2. Data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # 3. Normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # 4. Pre-encoder input block
- # NOTE(kan-bayashi): Use getattr to keep the compatibility
- if getattr(args, "preencoder", None) is not None:
- preencoder_class = preencoder_choices.get_class(args.preencoder)
- preencoder = preencoder_class(**args.preencoder_conf)
- input_size = preencoder.output_size()
- else:
- preencoder = None
-
- # 5. Encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(input_size=input_size, **args.encoder_conf)
-
- # 6. Post-encoder block
- # NOTE(kan-bayashi): Use getattr to keep the compatibility
- encoder_output_size = encoder.output_size()
- if getattr(args, "postencoder", None) is not None:
- postencoder_class = postencoder_choices.get_class(args.postencoder)
- postencoder = postencoder_class(
- input_size=encoder_output_size, **args.postencoder_conf
- )
- encoder_output_size = postencoder.output_size()
- else:
- postencoder = None
-
- # 7. Pooling layer
- pooling_class = pooling_choices.get_class(args.pooling_type)
- pooling_dim = (2, 3)
- eps = 1e-12
- if hasattr(args, "pooling_type_conf"):
- if "pooling_dim" in args.pooling_type_conf:
- pooling_dim = args.pooling_type_conf["pooling_dim"]
- if "eps" in args.pooling_type_conf:
- eps = args.pooling_type_conf["eps"]
- pooling_layer = pooling_class(
- pooling_dim=pooling_dim,
- eps=eps,
- )
- if args.pooling_type == "statistic":
- encoder_output_size *= 2
-
- # 8. Decoder
- decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **args.decoder_conf,
- )
-
- # 7. Build model
- try:
- model_class = model_choices.get_class(args.model)
- except AttributeError:
- model_class = model_choices.get_class("espnet")
- model = model_class(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- pooling_layer=pooling_layer,
- decoder=decoder,
- **args.model_conf,
- )
-
- # FIXME(kamo): Should be done in model?
- # 8. Initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
deleted file mode 100644
index 498d05d..0000000
--- a/funasr/build_utils/build_trainer.py
+++ /dev/null
@@ -1,812 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Trainer module."""
-import argparse
-import dataclasses
-import logging
-import os
-import time
-from contextlib import contextmanager
-from dataclasses import is_dataclass
-from distutils.version import LooseVersion
-from io import BytesIO
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import humanfriendly
-import oss2
-import torch
-import torch.nn
-import torch.optim
-
-from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.main_funcs.average_nbest_models import average_nbest_models
-from funasr.models.base_model import FunASRModel
-from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
-from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
-from funasr.schedulers.abs_scheduler import AbsScheduler
-from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
-from funasr.torch_utils.add_gradient_noise import add_gradient_noise
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.recursive_op import recursive_average
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.distributed_utils import DistributedOption
-from funasr.train.reporter import Reporter
-from funasr.train.reporter import SubReporter
-from funasr.utils.build_dataclass import build_dataclass
-
-if torch.distributed.is_available():
- from torch.distributed import ReduceOp
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
- from torch.cuda.amp import GradScaler
-else:
- # Nothing to do if torch<1.6.0
- @contextmanager
- def autocast(enabled=True):
- yield
-
-
- GradScaler = None
-
-try:
- import fairscale
-except ImportError:
- fairscale = None
-
-
-@dataclasses.dataclass
-class TrainerOptions:
- ngpu: int
- resume: bool
- use_amp: bool
- train_dtype: str
- grad_noise: bool
- accum_grad: int
- grad_clip: float
- grad_clip_type: float
- log_interval: Optional[int]
- # no_forward_run: bool
- use_tensorboard: bool
- # use_wandb: bool
- output_dir: Union[Path, str]
- max_epoch: int
- max_update: int
- seed: int
- # sharded_ddp: bool
- patience: Optional[int]
- keep_nbest_models: Union[int, List[int]]
- nbest_averaging_interval: int
- early_stopping_criterion: Sequence[str]
- best_model_criterion: Sequence[Sequence[str]]
- val_scheduler_criterion: Sequence[str]
- unused_parameters: bool
- # wandb_model_log_interval: int
- use_pai: bool
- oss_bucket: Union[oss2.Bucket, None]
-
-
-class Trainer:
- """Trainer
-
- """
-
- def __init__(self,
- args,
- model: FunASRModel,
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- train_dataloader: AbsIterFactory,
- valid_dataloader: AbsIterFactory,
- distributed_option: DistributedOption):
- self.trainer_options = self.build_options(args)
- self.model = model
- self.optimizers = optimizers
- self.schedulers = schedulers
- self.train_dataloader = train_dataloader
- self.valid_dataloader = valid_dataloader
- self.distributed_option = distributed_option
-
- def build_options(self, args: argparse.Namespace) -> TrainerOptions:
- """Build options consumed by train(), eval()"""
- return build_dataclass(TrainerOptions, args)
-
- @classmethod
- def add_arguments(cls, parser: argparse.ArgumentParser):
- """Reserved for future development of another Trainer"""
- pass
-
- def resume(self,
- checkpoint: Union[str, Path],
- model: torch.nn.Module,
- reporter: Reporter,
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- scaler: Optional[GradScaler],
- ngpu: int = 0,
- ):
- states = torch.load(
- checkpoint,
- map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
- )
- model.load_state_dict(states["model"])
- reporter.load_state_dict(states["reporter"])
- for optimizer, state in zip(optimizers, states["optimizers"]):
- optimizer.load_state_dict(state)
- for scheduler, state in zip(schedulers, states["schedulers"]):
- if scheduler is not None:
- scheduler.load_state_dict(state)
- if scaler is not None:
- if states["scaler"] is None:
- logging.warning("scaler state is not found")
- else:
- scaler.load_state_dict(states["scaler"])
-
- logging.info(f"The training was resumed using {checkpoint}")
-
- def run(self) -> None:
- """Perform training. This method performs the main process of training."""
- # NOTE(kamo): Don't check the type more strictly as far trainer_options
- model = self.model
- optimizers = self.optimizers
- schedulers = self.schedulers
- train_dataloader = self.train_dataloader
- valid_dataloader = self.valid_dataloader
- trainer_options = self.trainer_options
- distributed_option = self.distributed_option
- assert is_dataclass(trainer_options), type(trainer_options)
- assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
-
- if isinstance(trainer_options.keep_nbest_models, int):
- keep_nbest_models = [trainer_options.keep_nbest_models]
- else:
- if len(trainer_options.keep_nbest_models) == 0:
- logging.warning("No keep_nbest_models is given. Change to [1]")
- trainer_options.keep_nbest_models = [1]
- keep_nbest_models = trainer_options.keep_nbest_models
-
- output_dir = Path(trainer_options.output_dir)
- reporter = Reporter()
- if trainer_options.use_amp:
- if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
- raise RuntimeError(
- "Require torch>=1.6.0 for Automatic Mixed Precision"
- )
- # if trainer_options.sharded_ddp:
- # if fairscale is None:
- # raise RuntimeError(
- # "Requiring fairscale. Do 'pip install fairscale'"
- # )
- # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
- # else:
- scaler = GradScaler()
- else:
- scaler = None
-
- if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
- self.resume(
- checkpoint=output_dir / "checkpoint.pb",
- model=model,
- optimizers=optimizers,
- schedulers=schedulers,
- reporter=reporter,
- scaler=scaler,
- ngpu=trainer_options.ngpu,
- )
-
- start_epoch = reporter.get_epoch() + 1
- if start_epoch == trainer_options.max_epoch + 1:
- logging.warning(
- f"The training has already reached at max_epoch: {start_epoch}"
- )
-
- if distributed_option.distributed:
- dp_model = torch.nn.parallel.DistributedDataParallel(
- model, find_unused_parameters=trainer_options.unused_parameters)
- elif distributed_option.ngpu > 1:
- dp_model = torch.nn.parallel.DataParallel(
- model,
- device_ids=list(range(distributed_option.ngpu)),
- )
- else:
- # NOTE(kamo): DataParallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- dp_model = model
-
- if trainer_options.use_tensorboard and (
- not distributed_option.distributed or distributed_option.dist_rank == 0
- ):
- from torch.utils.tensorboard import SummaryWriter
- if trainer_options.use_pai:
- train_summary_writer = SummaryWriter(
- os.path.join(trainer_options.output_dir, "tensorboard/train")
- )
- valid_summary_writer = SummaryWriter(
- os.path.join(trainer_options.output_dir, "tensorboard/valid")
- )
- else:
- train_summary_writer = SummaryWriter(
- str(output_dir / "tensorboard" / "train")
- )
- valid_summary_writer = SummaryWriter(
- str(output_dir / "tensorboard" / "valid")
- )
- else:
- train_summary_writer = None
-
- start_time = time.perf_counter()
- for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
- if iepoch != start_epoch:
- logging.info(
- "{}/{}epoch started. Estimated time to finish: {} hours".format(
- iepoch,
- trainer_options.max_epoch,
- (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
- trainer_options.max_epoch - iepoch + 1),
- )
- )
- else:
- logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
- set_all_random_seed(trainer_options.seed + iepoch)
-
- reporter.set_epoch(iepoch)
- # 1. Train and validation for one-epoch
- with reporter.observe("train") as sub_reporter:
- all_steps_are_invalid, max_update_stop = self.train_one_epoch(
- model=dp_model,
- optimizers=optimizers,
- schedulers=schedulers,
- iterator=train_dataloader.build_iter(iepoch),
- reporter=sub_reporter,
- scaler=scaler,
- summary_writer=train_summary_writer,
- options=trainer_options,
- distributed_option=distributed_option,
- )
-
- with reporter.observe("valid") as sub_reporter:
- self.validate_one_epoch(
- model=dp_model,
- iterator=valid_dataloader.build_iter(iepoch),
- reporter=sub_reporter,
- options=trainer_options,
- distributed_option=distributed_option,
- )
-
- # 2. LR Scheduler step
- for scheduler in schedulers:
- if isinstance(scheduler, AbsValEpochStepScheduler):
- scheduler.step(
- reporter.get_value(*trainer_options.val_scheduler_criterion)
- )
- elif isinstance(scheduler, AbsEpochStepScheduler):
- scheduler.step()
- # if trainer_options.sharded_ddp:
- # for optimizer in optimizers:
- # if isinstance(optimizer, fairscale.optim.oss.OSS):
- # optimizer.consolidate_state_dict()
-
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- # 3. Report the results
- logging.info(reporter.log_message())
- if train_summary_writer is not None:
- reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
- reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
- # if trainer_options.use_wandb:
- # reporter.wandb_log()
-
- # save tensorboard on oss
- if trainer_options.use_pai and train_summary_writer is not None:
- def write_tensorboard_summary(summary_writer_path, oss_bucket):
- file_list = []
- for root, dirs, files in os.walk(summary_writer_path, topdown=False):
- for name in files:
- file_full_path = os.path.join(root, name)
- file_list.append(file_full_path)
-
- for file_full_path in file_list:
- with open(file_full_path, "rb") as f:
- oss_bucket.put_object(file_full_path, f)
-
- write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
- trainer_options.oss_bucket)
- write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
- trainer_options.oss_bucket)
-
- # 4. Save/Update the checkpoint
- if trainer_options.use_pai:
- buffer = BytesIO()
- torch.save(
- {
- "model": model.state_dict(),
- "reporter": reporter.state_dict(),
- "optimizers": [o.state_dict() for o in optimizers],
- "schedulers": [
- s.state_dict() if s is not None else None
- for s in schedulers
- ],
- "scaler": scaler.state_dict() if scaler is not None else None,
- "ema_model": model.encoder.ema.model.state_dict()
- if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
- },
- buffer,
- )
- trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
- buffer.getvalue())
- else:
- torch.save(
- {
- "model": model.state_dict(),
- "reporter": reporter.state_dict(),
- "optimizers": [o.state_dict() for o in optimizers],
- "schedulers": [
- s.state_dict() if s is not None else None
- for s in schedulers
- ],
- "scaler": scaler.state_dict() if scaler is not None else None,
- },
- output_dir / "checkpoint.pb",
- )
-
- # 5. Save and log the model and update the link to the best model
- if trainer_options.use_pai:
- buffer = BytesIO()
- torch.save(model.state_dict(), buffer)
- trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
- f"{iepoch}epoch.pb"), buffer.getvalue())
- else:
- torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
-
- # Creates a sym link latest.pb -> {iepoch}epoch.pb
- if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, "latest.pb")
- if trainer_options.oss_bucket.object_exists(p):
- trainer_options.oss_bucket.delete_object(p)
- trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
- os.path.join(trainer_options.output_dir,
- f"{iepoch}epoch.pb"), p)
- else:
- p = output_dir / "latest.pb"
- if p.is_symlink() or p.exists():
- p.unlink()
- p.symlink_to(f"{iepoch}epoch.pb")
-
- _improved = []
- for _phase, k, _mode in trainer_options.best_model_criterion:
- # e.g. _phase, k, _mode = "train", "loss", "min"
- if reporter.has(_phase, k):
- best_epoch = reporter.get_best_epoch(_phase, k, _mode)
- # Creates sym links if it's the best result
- if best_epoch == iepoch:
- if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
- if trainer_options.oss_bucket.object_exists(p):
- trainer_options.oss_bucket.delete_object(p)
- trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
- os.path.join(trainer_options.output_dir,
- f"{iepoch}epoch.pb"), p)
- else:
- p = output_dir / f"{_phase}.{k}.best.pb"
- if p.is_symlink() or p.exists():
- p.unlink()
- p.symlink_to(f"{iepoch}epoch.pb")
- _improved.append(f"{_phase}.{k}")
- if len(_improved) == 0:
- logging.info("There are no improvements in this epoch")
- else:
- logging.info(
- "The best model has been updated: " + ", ".join(_improved)
- )
-
- # log_model = (
- # trainer_options.wandb_model_log_interval > 0
- # and iepoch % trainer_options.wandb_model_log_interval == 0
- # )
- # if log_model and trainer_options.use_wandb:
- # import wandb
- #
- # logging.info("Logging Model on this epoch :::::")
- # artifact = wandb.Artifact(
- # name=f"model_{wandb.run.id}",
- # type="model",
- # metadata={"improved": _improved},
- # )
- # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
- # aliases = [
- # f"epoch-{iepoch}",
- # "best" if best_epoch == iepoch else "",
- # ]
- # wandb.log_artifact(artifact, aliases=aliases)
-
- # 6. Remove the model files excluding n-best epoch and latest epoch
- _removed = []
- # Get the union set of the n-best among multiple criterion
- nbests = set().union(
- *[
- set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
- for ph, k, m in trainer_options.best_model_criterion
- if reporter.has(ph, k)
- ]
- )
-
- # Generated n-best averaged model
- if (
- trainer_options.nbest_averaging_interval > 0
- and iepoch % trainer_options.nbest_averaging_interval == 0
- ):
- average_nbest_models(
- reporter=reporter,
- output_dir=output_dir,
- best_model_criterion=trainer_options.best_model_criterion,
- nbest=keep_nbest_models,
- suffix=f"till{iepoch}epoch",
- oss_bucket=trainer_options.oss_bucket,
- pai_output_dir=trainer_options.output_dir,
- )
-
- for e in range(1, iepoch):
- if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
- if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
- trainer_options.oss_bucket.delete_object(p)
- _removed.append(str(p))
- else:
- p = output_dir / f"{e}epoch.pb"
- if p.exists() and e not in nbests:
- p.unlink()
- _removed.append(str(p))
- if len(_removed) != 0:
- logging.info("The model files were removed: " + ", ".join(_removed))
-
- # 7. If any updating haven't happened, stops the training
- if all_steps_are_invalid:
- logging.warning(
- f"The gradients at all steps are invalid in this epoch. "
- f"Something seems wrong. This training was stopped at {iepoch}epoch"
- )
- break
-
- if max_update_stop:
- logging.info(
- f"Stopping training due to "
- f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
- )
- break
-
- # 8. Check early stopping
- if trainer_options.patience is not None:
- if reporter.check_early_stopping(
- trainer_options.patience, *trainer_options.early_stopping_criterion
- ):
- break
-
- else:
- logging.info(
- f"The training was finished at {trainer_options.max_epoch} epochs "
- )
-
- # Generated n-best averaged model
- if not distributed_option.distributed or distributed_option.dist_rank == 0:
- average_nbest_models(
- reporter=reporter,
- output_dir=output_dir,
- best_model_criterion=trainer_options.best_model_criterion,
- nbest=keep_nbest_models,
- oss_bucket=trainer_options.oss_bucket,
- pai_output_dir=trainer_options.output_dir,
- )
-
- def train_one_epoch(
- self,
- model: torch.nn.Module,
- iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- scaler: Optional[GradScaler],
- reporter: SubReporter,
- summary_writer,
- options: TrainerOptions,
- distributed_option: DistributedOption,
- ) -> Tuple[bool, bool]:
-
- grad_noise = options.grad_noise
- accum_grad = options.accum_grad
- grad_clip = options.grad_clip
- grad_clip_type = options.grad_clip_type
- log_interval = options.log_interval
- # no_forward_run = options.no_forward_run
- ngpu = options.ngpu
- # use_wandb = options.use_wandb
- distributed = distributed_option.distributed
-
- if log_interval is None:
- try:
- log_interval = max(len(iterator) // 20, 10)
- except TypeError:
- log_interval = 100
-
- model.train()
- all_steps_are_invalid = True
- max_update_stop = False
- # [For distributed] Because iteration counts are not always equals between
- # processes, send stop-flag to the other processes if iterator is finished
- iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
-
- start_time = time.perf_counter()
- for iiter, (_, batch) in enumerate(
- reporter.measure_iter_time(iterator, "iter_time"), 1
- ):
- assert isinstance(batch, dict), type(batch)
-
- if distributed:
- torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
- if iterator_stop > 0:
- break
-
- batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- # if no_forward_run:
- # all_steps_are_invalid = False
- # continue
-
- with autocast(scaler is not None):
- with reporter.measure_time("forward_time"):
- retval = model(**batch)
-
- # Note(kamo):
- # Supporting two patterns for the returned value from the model
- # a. dict type
- if isinstance(retval, dict):
- loss = retval["loss"]
- stats = retval["stats"]
- weight = retval["weight"]
- optim_idx = retval.get("optim_idx")
- if optim_idx is not None and not isinstance(optim_idx, int):
- if not isinstance(optim_idx, torch.Tensor):
- raise RuntimeError(
- "optim_idx must be int or 1dim torch.Tensor, "
- f"but got {type(optim_idx)}"
- )
- if optim_idx.dim() >= 2:
- raise RuntimeError(
- "optim_idx must be int or 1dim torch.Tensor, "
- f"but got {optim_idx.dim()}dim tensor"
- )
- if optim_idx.dim() == 1:
- for v in optim_idx:
- if v != optim_idx[0]:
- raise RuntimeError(
- "optim_idx must be 1dim tensor "
- "having same values for all entries"
- )
- optim_idx = optim_idx[0].item()
- else:
- optim_idx = optim_idx.item()
-
- # b. tuple or list type
- else:
- loss, stats, weight = retval
- optim_idx = None
-
- stats = {k: v for k, v in stats.items() if v is not None}
- if ngpu > 1 or distributed:
- # Apply weighted averaging for loss and stats
- loss = (loss * weight.type(loss.dtype)).sum()
-
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed)
-
- # Now weight is summation over all workers
- loss /= weight
- if distributed:
- # NOTE(kamo): Multiply world_size because DistributedDataParallel
- # automatically normalizes the gradient by world_size.
- loss *= torch.distributed.get_world_size()
-
- loss /= accum_grad
-
- reporter.register(stats, weight)
-
- with reporter.measure_time("backward_time"):
- if scaler is not None:
- # Scales loss. Calls backward() on scaled loss
- # to create scaled gradients.
- # Backward passes under autocast are not recommended.
- # Backward ops run in the same dtype autocast chose
- # for corresponding forward ops.
- scaler.scale(loss).backward()
- else:
- loss.backward()
-
- if iiter % accum_grad == 0:
- if scaler is not None:
- # Unscales the gradients of optimizer's assigned params in-place
- for iopt, optimizer in enumerate(optimizers):
- if optim_idx is not None and iopt != optim_idx:
- continue
- scaler.unscale_(optimizer)
-
- # gradient noise injection
- if grad_noise:
- add_gradient_noise(
- model,
- reporter.get_total_count(),
- duration=100,
- eta=1.0,
- scale_factor=0.55,
- )
-
- # compute the gradient norm to check if it is normal or not
- grad_norm = torch.nn.utils.clip_grad_norm_(
- model.parameters(),
- max_norm=grad_clip,
- norm_type=grad_clip_type,
- )
- # PyTorch<=1.4, clip_grad_norm_ returns float value
- if not isinstance(grad_norm, torch.Tensor):
- grad_norm = torch.tensor(grad_norm)
-
- if not torch.isfinite(grad_norm):
- logging.warning(
- f"The grad norm is {grad_norm}. Skipping updating the model."
- )
-
- # Must invoke scaler.update() if unscale_() is used in the iteration
- # to avoid the following error:
- # RuntimeError: unscale_() has already been called
- # on this optimizer since the last update().
- # Note that if the gradient has inf/nan values,
- # scaler.step skips optimizer.step().
- if scaler is not None:
- for iopt, optimizer in enumerate(optimizers):
- if optim_idx is not None and iopt != optim_idx:
- continue
- scaler.step(optimizer)
- scaler.update()
-
- else:
- all_steps_are_invalid = False
- with reporter.measure_time("optim_step_time"):
- for iopt, (optimizer, scheduler) in enumerate(
- zip(optimizers, schedulers)
- ):
- if optim_idx is not None and iopt != optim_idx:
- continue
- if scaler is not None:
- # scaler.step() first unscales the gradients of
- # the optimizer's assigned params.
- scaler.step(optimizer)
- # Updates the scale for next iteration.
- scaler.update()
- else:
- optimizer.step()
- if isinstance(scheduler, AbsBatchStepScheduler):
- scheduler.step()
- for iopt, optimizer in enumerate(optimizers):
- if optim_idx is not None and iopt != optim_idx:
- continue
- optimizer.zero_grad()
-
- # Register lr and train/load time[sec/step],
- # where step refers to accum_grad * mini-batch
- reporter.register(
- dict(
- {
- f"optim{i}_lr{j}": pg["lr"]
- for i, optimizer in enumerate(optimizers)
- for j, pg in enumerate(optimizer.param_groups)
- if "lr" in pg
- },
- train_time=time.perf_counter() - start_time,
- ),
- )
- start_time = time.perf_counter()
-
- # update num_updates
- if distributed:
- if hasattr(model.module, "num_updates"):
- model.module.set_num_updates(model.module.get_num_updates() + 1)
- options.num_updates = model.module.get_num_updates()
- if model.module.get_num_updates() >= options.max_update:
- max_update_stop = True
- else:
- if hasattr(model, "num_updates"):
- model.set_num_updates(model.get_num_updates() + 1)
- options.num_updates = model.get_num_updates()
- if model.get_num_updates() >= options.max_update:
- max_update_stop = True
-
- # NOTE(kamo): Call log_message() after next()
- reporter.next()
- if iiter % log_interval == 0:
- num_updates = options.num_updates if hasattr(options, "num_updates") else None
- logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
- if summary_writer is not None:
- reporter.tensorboard_add_scalar(summary_writer, -log_interval)
- # if use_wandb:
- # reporter.wandb_log()
-
- if max_update_stop:
- break
-
- else:
- if distributed:
- iterator_stop.fill_(1)
- torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
- return all_steps_are_invalid, max_update_stop
-
- @torch.no_grad()
- def validate_one_epoch(
- self,
- model: torch.nn.Module,
- iterator: Iterable[Dict[str, torch.Tensor]],
- reporter: SubReporter,
- options: TrainerOptions,
- distributed_option: DistributedOption,
- ) -> None:
- ngpu = options.ngpu
- # no_forward_run = options.no_forward_run
- distributed = distributed_option.distributed
-
- model.eval()
-
- # [For distributed] Because iteration counts are not always equals between
- # processes, send stop-flag to the other processes if iterator is finished
- iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
- for (_, batch) in iterator:
- assert isinstance(batch, dict), type(batch)
- if distributed:
- torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
- if iterator_stop > 0:
- break
-
- batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
- # if no_forward_run:
- # continue
-
- retval = model(**batch)
- if isinstance(retval, dict):
- stats = retval["stats"]
- weight = retval["weight"]
- else:
- _, stats, weight = retval
- if ngpu > 1 or distributed:
- # Apply weighted averaging for stats.
- # if distributed, this method can also apply all_reduce()
- stats, weight = recursive_average(stats, weight, distributed)
-
- reporter.register(stats, weight)
- reporter.next()
-
- else:
- if distributed:
- iterator_stop.fill_(1)
- torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
-
-
-def build_trainer(
- args,
- model: FunASRModel,
- optimizers: Sequence[torch.optim.Optimizer],
- schedulers: Sequence[Optional[AbsScheduler]],
- train_dataloader: AbsIterFactory,
- valid_dataloader: AbsIterFactory,
- distributed_option: DistributedOption
-):
- trainer = Trainer(
- args=args,
- model=model,
- optimizers=optimizers,
- schedulers=schedulers,
- train_dataloader=train_dataloader,
- valid_dataloader=valid_dataloader,
- distributed_option=distributed_option
- )
- return trainer
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
deleted file mode 100644
index 6a840cf..0000000
--- a/funasr/build_utils/build_vad_model.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-
-from funasr.models.e2e_vad import E2EVadModel
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
- name="frontend",
- classes=dict(
- default=DefaultFrontend,
- sliding_window=SlidingWindow,
- s3prl=S3prlFrontend,
- fused=FusedFrontends,
- wav_frontend=WavFrontend,
- wav_frontend_online=WavFrontendOnline,
- ),
- default="default",
-)
-encoder_choices = ClassChoices(
- "encoder",
- classes=dict(
- fsmn=FSMN,
- ),
- type_check=torch.nn.Module,
- default="fsmn",
-)
-model_choices = ClassChoices(
- "model",
- classes=dict(
- e2evad=E2EVadModel,
- ),
- default="e2evad",
-)
-
-class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --model and --model_conf
- model_choices,
-]
-
-
-def build_vad_model(args):
- # frontend
- if not hasattr(args, "cmvn_file"):
- args.cmvn_file = None
- if not hasattr(args, "init"):
- args.init = None
- if args.input_size is None:
- frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend':
- frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
- else:
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size
-
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(**args.encoder_conf)
-
- model_class = model_choices.get_class(args.model)
- model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
-
- # initialize
- if args.init is not None:
- initialize(model, args.init)
-
- return model
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index c62153e..a22d5d4 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -35,8 +35,9 @@
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
import pdb; pdb.set_trace()
- if kwargs.get("model_pretrain"):
- kwargs = download_model(**kwargs)
+ if ":" in kwargs["model"]:
+ logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
import pdb;
pdb.set_trace()
@@ -84,8 +85,7 @@
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
- init_param = init_param
- if isinstance(init_param, Sequence):
+ if not isinstance(init_param, Sequence):
init_param = (init_param,)
logging.info("init_param is not None: %s", init_param)
for p in init_param:
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index 7f2cd83..21df89e 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -8,33 +8,7 @@
import time
import logging
-def load_audio(audio_path: str, fs: int=16000):
- audio = None
- if audio_path.startswith("oss:"):
- pass
- elif audio_path.startswith("odps:"):
- pass
- else:
- if ".ark:" in audio_path:
- audio = kaldiio.load_mat(audio_path)
- else:
- # audio, fs = librosa.load(audio_path, sr=fs)
- audio, fs = torchaudio.load(audio_path)
- audio = audio[0, :]
- return audio
-
-def extract_features(data, date_type: str="sound", frontend=None):
- if date_type == "sound":
-
- if isinstance(data, np.ndarray):
- data = torch.from_numpy(data).to(torch.float32)
- data_len = torch.tensor([data.shape[0]]).to(torch.int32)
- feat, feats_lens = frontend(data[None, :], data_len)
-
- feat = feat[0, :, :]
- else:
- feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
- return feat, feats_lens
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
@@ -115,17 +89,16 @@
def __getitem__(self, index):
item = self.indexed_dataset[index]
- # return item
source = item["source"]
data_src = load_audio(source, fs=self.fs)
- speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
+ speech, speech_lengths = extract_fbank(data_src, self.data_type, self.frontend) # speech: [b, T, d]
target = item["target"]
ids = self.tokenizer.encode(target)
ids_lengths = len(ids)
text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
- return {"speech": speech,
+ return {"speech": speech[0, :, :],
"speech_lengths": speech_lengths,
"text": text,
"text_lengths": text_lengths,
diff --git a/funasr/build_utils/__init__.py b/funasr/datasets/fun_datasets/__init__.py
similarity index 100%
copy from funasr/build_utils/__init__.py
copy to funasr/datasets/fun_datasets/__init__.py
diff --git a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py b/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
new file mode 100644
index 0000000..c76f346
--- /dev/null
+++ b/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import json
+import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
+import torchaudio
+import time
+import logging
+from torch.nn.utils.rnn import pad_sequence
+
def load_audio(audio_or_path_or_list, fs: int = 16000, audio_fs: int = 16000):
    """Load audio into a 1-D waveform, resampling to `fs` when needed.

    Args:
        audio_or_path_or_list: a file path, a raw waveform (np.ndarray or
            torch.Tensor), or a list/tuple of either (handled recursively).
        fs: target sample rate.
        audio_fs: sample rate of raw-array/tensor inputs; overwritten by the
            file's own rate when a path is given.

    Returns:
        A 1-D waveform (torch.Tensor for file input, otherwise the input
        after squeezing), resampled to `fs` when audio_fs differs.
    """
    if isinstance(audio_or_path_or_list, (list, tuple)):
        return [load_audio(audio, fs=fs, audio_fs=audio_fs) for audio in audio_or_path_or_list]

    if isinstance(audio_or_path_or_list, str) and os.path.exists(audio_or_path_or_list):
        audio_or_path_or_list, audio_fs = torchaudio.load(audio_or_path_or_list)
        audio_or_path_or_list = audio_or_path_or_list[0, :]  # keep first channel only
    elif isinstance(audio_or_path_or_list, np.ndarray):  # audio sample point
        audio_or_path_or_list = np.squeeze(audio_or_path_or_list)  # [n_samples,]

    if audio_fs != fs:
        # BUG FIX: the resampled waveform was computed but discarded, so audio
        # at the wrong sample rate leaked through unchanged.
        # NOTE(review): this branch assumes a torch.Tensor input — a raw
        # np.ndarray with a mismatched rate would fail inside Resample; same
        # as the original code. TODO confirm callers only resample tensors.
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        audio_or_path_or_list = resampler(audio_or_path_or_list[None, :])[0, :]
    return audio_or_path_or_list
+#
+# def load_audio_from_list(audio_list, fs: int=16000, audio_fs: int=16000):
+# if isinstance(audio_list, (list, tuple)):
+# return [load_audio(audio_or_path, fs=fs, audio_fs=audio_fs) for audio_or_path in audio_list]
+
+
def load_bytes(input):
    """Decode raw little-endian int16 PCM bytes into float32 in [-1.0, 1.0).

    Args:
        input: a bytes-like object containing int16 samples.

    Returns:
        np.ndarray of float32 samples scaled by 1/32768.
    """
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    # Defensive checks kept from the original; with the fixed int16/float32
    # types above they cannot actually fire.
    if middle_data.dtype.kind not in 'iu':
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype('float32')
    if dtype.kind != 'f':
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)          # 32768 for int16
    offset = i.min + abs_max             # 0 for two's-complement int16
    # The scaled result is already float32; the original wrapped it in a
    # redundant np.frombuffer reinterpretation (a byte-level no-op here).
    array = ((middle_data.astype(dtype) - offset) / abs_max).astype(np.float32)
    return array
+
def extract_fbank(data, data_len=None, date_type: str = "sound", frontend=None):
    """Convert a waveform (or batch of waveforms) into frontend features.

    Args:
        data: np.ndarray / torch.Tensor of shape [N] or [batch, N], or a
            list/tuple of 1-D arrays/tensors (padded into a batch here).
        data_len: optional per-item lengths; inferred when None.
        date_type: "sound" applies the frontend; any other value passes the
            input through (dtype casts only). Name kept as-is — callers pass
            it by keyword; "date" is a historical typo for "data".
            NOTE(review): dataset_jsonl calls this positionally as
            extract_fbank(data, data_type, frontend), which misaligns with
            this signature — the caller should use keywords; verify.
        frontend: feature extractor called as frontend(data, data_len);
            required when date_type == "sound".

    Returns:
        (feats, feats_lens): float32 features and int32 lengths of shape
        [batch] (one entry per batch item).
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        # BUG FIX: the original tested `len(data) < 2` (the sample count), so
        # any single waveform of >= 2 samples skipped batching and left
        # data_len unset. The intended test is the tensor rank.
        if data.dim() < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            data_len = [data.shape[1]]
    elif isinstance(data, (list, tuple)):
        data_list, data_len = [], []
        for data_i in data:
            # BUG FIX: the original checked isinstance(data, np.ndarray) —
            # the container — so ndarray items were never converted.
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
        data = pad_sequence(data_list, batch_first=True)  # data: [batch, N]

    if date_type == "sound":
        data, data_len = frontend(data, data_len)

    if isinstance(data_len, (list, tuple)):
        # BUG FIX: torch.tensor([data_len]) produced shape (1, batch);
        # lengths must be one-dimensional, shape (batch,).
        data_len = torch.tensor(data_len)
    return data.to(torch.float32), data_len.to(torch.int32)
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index f92f322..ac16065 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -116,7 +116,7 @@
def forward(
self,
input: torch.Tensor,
- input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ input_lengths) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
diff --git a/funasr/build_utils/__init__.py b/funasr/models/paraformer/__init__.py
similarity index 100%
rename from funasr/build_utils/__init__.py
rename to funasr/models/paraformer/__init__.py
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
new file mode 100644
index 0000000..75b36a9
--- /dev/null
+++ b/funasr/models/paraformer/model.py
@@ -0,0 +1,1760 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.models.frontend.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.predictor.cif import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.predictor.cif import CifPredictorV3
+from funasr.models.paraformer.search import Hypothesis
+
+from funasr.cli.model_class_factory import *
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
class Paraformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Components (frontend/specaug/normalize/encoder/decoder/ctc/predictor) are
    given as class names and instantiated via the *_choices factories pulled
    in by the `model_class_factory` star import.
    """

    def __init__(
            self,
            # token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[str] = None,
            frontend_conf: Optional[Dict] = None,
            specaug: Optional[str] = None,
            specaug_conf: Optional[Dict] = None,
            normalize: str = None,
            normalize_conf: Optional[Dict] = None,
            encoder: str = None,
            encoder_conf: Optional[Dict] = None,
            decoder: str = None,
            decoder_conf: Optional[Dict] = None,
            ctc: str = None,
            ctc_conf: Optional[Dict] = None,
            predictor: str = None,
            predictor_conf: Optional[Dict] = None,
            ctc_weight: float = 0.5,
            input_size: int = 80,
            vocab_size: int = -1,
            ignore_id: int = -1,
            blank_id: int = 0,
            sos: int = 1,
            eos: int = 2,
            lsm_weight: float = 0.0,
            length_normalized_loss: bool = False,
            # report_cer: bool = True,
            # report_wer: bool = True,
            # sym_space: str = "<space>",
            # sym_blank: str = "<blank>",
            # extract_feats_in_collect_stats: bool = True,
            # predictor=None,
            predictor_weight: float = 0.0,
            predictor_bias: int = 0,
            sampling_ratio: float = 0.2,
            share_embedding: bool = False,
            # preencoder: Optional[AbsPreEncoder] = None,
            # postencoder: Optional[AbsPostEncoder] = None,
            use_1st_decoder_loss: bool = False,
            **kwargs,
    ):

        super().__init__()

        # import pdb;
        # pdb.set_trace()

        # Instantiate optional components from their factory-registered names.
        if frontend is not None:
            frontend_class = frontend_choices.get_class(frontend)
            frontend = frontend_class(**frontend_conf)
        if specaug is not None:
            specaug_class = specaug_choices.get_class(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = normalize_choices.get_class(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = encoder_choices.get_class(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        if decoder is not None:
            decoder_class = decoder_choices.get_class(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:

            if ctc_conf is None:
                ctc_conf = {}

            # NOTE(review): the explicit `from funasr.models.ctc import CTC`
            # at the top of this file is commented out; CTC is presumably
            # re-exported by the model_class_factory star import — confirm,
            # otherwise this raises NameError when ctc_weight > 0.
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        if predictor is not None:
            predictor_class = predictor_choices.get_class(predictor)
            predictor = predictor_class(**predictor_conf)

        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # self.token_list = token_list.copy()
        #
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        # self.preencoder = preencoder
        # self.postencoder = postencoder
        self.encoder = encoder
        #
        # if not hasattr(self.encoder, "interctc_use_conditioning"):
        #     self.encoder.interctc_use_conditioning = False
        # if self.encoder.interctc_use_conditioning:
        #     self.encoder.conditioning_layer = torch.nn.Linear(
        #         vocab_size, self.encoder.output_size()
        #     )
        #
        # self.error_calculator = None
        #
        # NOTE(review): self.error_calculator is never assigned (the code
        # above is commented out), but _calc_att_loss/_calc_ctc_loss read it
        # in eval mode — verify this does not raise AttributeError.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        #
        # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight  # weight of the CIF length loss
        self.predictor_bias = predictor_bias      # 1 when eos is appended to targets
        self.sampling_ratio = sampling_ratio      # glancing-sampling mix ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        # self.step_cur = 0
        #
        self.share_embedding = share_embedding
        if self.share_embedding:
            # Decoder reuses the output-layer weights as token embeddings.
            self.decoder.embed = None

        self.use_1st_decoder_loss = use_1st_decoder_loss
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None  # built lazily by init_beam_search()
+
    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        # import pdb;
        # pdb.set_trace()
        # Lengths may arrive as (Batch, 1) from the dataloader; flatten them.
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]


        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)


        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc


        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        # Attention loss is always combined with the predictor (token-count)
        # loss, weighted by predictor_weight.
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight


        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        # When length-normalizing, weight by total token count (+bias for the
        # appended eos) instead of the batch size.
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
+
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ pre_loss_att = None
+ if self.sampling_ratio > 0.0:
+
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Glancing-style sampler: mix ground-truth token embeddings into the
        predictor's acoustic embeddings, in proportion to the first-pass
        decoding error rate per utterance.

        Returns:
            (sematic_embeds, decoder_out): mixed embeddings and the first-pass
            decoder output, both zeroed beyond the target lengths.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            # Embeddings are rows of the (tied) output projection matrix.
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        # First decoding pass is gradient-free; it only supplies the error
        # statistics used to decide which positions to replace.
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                # Replace sampling_ratio of the wrongly-predicted positions
                # with ground-truth embeddings (random positions).
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(dim=0,
                                            index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
                                            value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        # Keep acoustic embeddings where input_mask is True, ground-truth
        # embeddings elsewhere (gradients flow through this mixing).
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.paraformer.search import BeamSearchPara
+ from funasr.modules.scorers.ctc import CTCPrefixScorer
+ from funasr.modules.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+ if self.ctc != None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+ decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+ ctc=kwargs.get("decoding_ctc_weight", 0.0),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearchPara(
+ beam_size=kwargs.get("beam_size", 2),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+ # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ # for scorer in scorers.values():
+ # if isinstance(scorer, torch.nn.Module):
+ # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+ self.beam_search = beam_search
+
+ def generate(self,
+ data_in: list,
+ data_lengths: list=None,
+ key: list=None,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+ return results, meta_data
+
+
+
class BiCifParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Paraformer variant using CIF predictor v3, which additionally produces a
    refined token count (trained via _calc_pre2_loss) and upsampled CIF peaks
    for word-level timestamp prediction.
    """

    def __init__(
            self,
            *args,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # NOTE(review): the explicit import of CifPredictorV3 at the top of
        # this file is commented out, so this name must be supplied by the
        # model_class_factory star import — confirm, otherwise NameError.
        assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
+
+
+ def _calc_pre2_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
+
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
+
+ return loss_pre2
+
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+
+ def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+
    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # Lengths may arrive as (Batch, 1) from the dataloader; flatten them.
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)


        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc


        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Second (refined) CIF token-count loss, specific to CIF-v3.
        loss_pre2 = self._calc_pre2_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        # The refined count loss carries half the predictor weight.
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
        else:
            loss = self.ctc_weight * loss_ctc + (
                    1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss_pre2"] = loss_pre2.detach().cpu()

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        # When length-normalizing, weight by total token count (+bias for the
        # appended eos) instead of the batch size.
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
+
+ def generate(self,
+ data_in: list,
+ data_lengths: list = None,
+ key: list = None,
+ tokenizer=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
+ frontend=self.frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data[
+ "batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ # predictor
+ predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+ pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ # BiCifParaformer, test no bias cif2
+
+ _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
+ pre_token_length)
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ yseq = torch.tensor(
+ [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+ us_peaks[i][:encoder_out_lens[i] * 3],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0))
+
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
+
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+ "time_stamp_postprocessed": time_stamp_postprocessed,
+ "word_lists": word_lists
+ }
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+
+ return results, meta_data
+
+
+class NeatContextualParaformer(Paraformer):
+    """
+    Contextual Paraformer: Paraformer with CLAS-style hotword biasing via a
+    bias encoder over hotword embeddings.
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        # Hotword/bias-encoder configuration, all optional and read from kwargs.
+        self.target_buffer_length = kwargs.get("target_buffer_length", -1)
+        inner_dim = kwargs.get("inner_dim", 256)
+        bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
+        use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
+        crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
+        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
+        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+
+
+        if bias_encoder_type == 'lstm':
+            # LSTM bias encoder: each hotword is embedded token-by-token and
+            # summarized by the LSTM state at its last valid position.
+            # NOTE(review): dropout on a single-layer LSTM is a no-op (PyTorch
+            # warns about this).
+            logging.warning("enable bias encoder sampling and contextual training")
+            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
+            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+        elif bias_encoder_type == 'mean':
+            self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+            logging.warning("enable bias encoder sampling and contextual training")
+        else:
+            # NOTE(review): unsupported type is only logged, not raised; later
+            # references to self.bias_encoder would fail with AttributeError.
+            logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
+
+        if self.target_buffer_length > 0:
+            # Streaming hotword-buffer state (unused when target_buffer_length <= 0).
+            self.hotword_buffer = None
+            self.length_record = []
+            self.current_buffer_length = 0
+        self.use_decoder_embedding = use_decoder_embedding
+        self.crit_attn_weight = crit_attn_weight
+        if self.crit_attn_weight > 0:
+            self.attn_loss = torch.nn.L1Loss()
+        self.crit_attn_smooth = crit_attn_smooth
+
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        hotword_pad: torch.Tensor,
+        hotword_lengths: torch.Tensor,
+        dha_pad: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+            hotword_pad: padded hotword token-id matrix; embedded and fed to
+                the bias encoder -- presumably (NumHotwords, MaxHotwordLen),
+                TODO confirm against the data loader
+            hotword_lengths: valid length of each row in hotword_pad
+            dha_pad: not used in this method; accepted for batch-interface
+                compatibility -- TODO confirm
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+        loss_ctc, cer_ctc = None, None
+
+        stats = dict()
+
+        # 1. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+
+        # 2b. Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
+        )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+        if loss_ideal is not None:
+            # Optional ideal-attention regularizer; currently always None (the
+            # computing branch is commented out in _calc_att_clas_loss).
+            loss = loss + loss_ideal * self.crit_attn_weight
+            stats["loss_ideal"] = loss_ideal.detach().cpu()
+
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+        stats["loss"] = torch.clone(loss.detach())
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
+
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+
+    def _calc_att_clas_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        hotword_pad: torch.Tensor,
+        hotword_lengths: torch.Tensor,
+    ):
+        # Attention-decoder + CIF-predictor loss with contextual (hotword) biasing.
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            # Append <eos> to targets so the predictor learns the +1 length.
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+                                                                     ignore_id=self.ignore_id)
+
+        # -1. bias encoder
+        if self.use_decoder_embedding:
+            hw_embed = self.decoder.embed(hotword_pad)
+        else:
+            hw_embed = self.bias_embed(hotword_pad)
+        # NOTE(review): assumes the 'lstm' bias encoder; with
+        # bias_encoder_type == 'mean', self.bias_encoder does not exist.
+        hw_embed, (_, _) = self.bias_encoder(hw_embed)
+        # Take the LSTM output at each hotword's last valid position.
+        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
+        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
+        # Broadcast the same hotword bank to every utterance in the batch.
+        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+        # 0. sampler
+        decoder_out_1st = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                                                           pre_acoustic_embeds, contextual_info)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+        '''
+        if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
+            ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
+            attn_non_blank = attn[:,:,:,:-1]
+            ideal_attn_non_blank = ideal_attn[:,:,:-1]
+            loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
+        else:
+            loss_ideal = None
+        '''
+        loss_ideal = None
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+        # Scheduled-sampling-style mixing: positions the first-pass decoder got
+        # wrong are partially replaced by ground-truth target embeddings.
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad)
+        with torch.no_grad():
+            # First decoding pass only guides the sampling; no gradients.
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                # Number of positions to replace grows with the error count.
+                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+                if target_num > 0:
+                    input_mask[li].scatter_(dim=0,
+                                            index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
+                                            value=0)
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        # Keep acoustic embeds where mask==1, ground-truth embeds elsewhere.
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+                                   clas_scale=1.0):
+        # Inference-time decoding with hotword biasing (clas_scale weights the
+        # contextual attention contribution).
+        if hw_list is None:
+            # No hotwords: a single dummy entry keeps tensor shapes valid.
+            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
+            hw_list_pad = pad_list(hw_list, 0)
+            if self.use_decoder_embedding:
+                hw_embed = self.decoder.embed(hw_list_pad)
+            else:
+                hw_embed = self.bias_embed(hw_list_pad)
+            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
+            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+        else:
+            hw_lengths = [len(i) for i in hw_list]
+            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+            if self.use_decoder_embedding:
+                hw_embed = self.decoder.embed(hw_list_pad)
+            else:
+                hw_embed = self.bias_embed(hw_list_pad)
+            # Pack so the LSTM final state reflects each hotword's true length.
+            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+                                                               enforce_sorted=False)
+            _, (h_n, _) = self.bias_encoder(hw_embed)
+            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
+    def generate(self,
+                 data_in: list,
+                 data_lengths: list = None,
+                 key: list = None,
+                 tokenizer=None,
+                 **kwargs,
+                 ):
+        """Inference entry: load audio, extract fbank, encode, predict token
+        count with CIF, then decode with optional hotword biasing."""
+
+        # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+
+        # extract fbank feats
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
+                                               frontend=self.frontend)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        # Total audio duration (ms) represented by this batch.
+        meta_data[
+            "batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+
+        # NOTE(review): Tensor.to() is not in-place -- the returned tensors are
+        # discarded here, so speech/speech_lengths stay on their original device.
+        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+        # hotword
+        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer)
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # predictor
+        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+                                                                        predictor_outs[2], predictor_outs[3]
+        pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            # CIF predicted no tokens anywhere in the batch.
+            return []
+
+
+        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
+                                                       pre_acoustic_embeds,
+                                                       pre_token_length,
+                                                       hw_list=self.hotword_list,
+                                                       clas_scale=kwargs.get("clas_scale", 1.0))
+        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        results = []
+        b, n, d = decoder_out.size()
+        for i in range(b):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            am_scores = decoder_out[i, :pre_token_length[i], :]
+            if self.beam_search is not None:
+                nbest_hyps = self.beam_search(
+                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+                    minlenratio=kwargs.get("minlenratio", 0.0)
+                )
+
+                nbest_hyps = nbest_hyps[: self.nbest]
+            else:
+                # Greedy decoding straight from the CIF-aligned decoder scores.
+                yseq = am_scores.argmax(dim=-1)
+                score = am_scores.max(dim=-1)[0]
+                score = torch.sum(score, dim=-1)
+                # pad with mask tokens to ensure compatibility with sos/eos tokens
+                yseq = torch.tensor(
+                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+                )
+                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if ibest_writer is None and kwargs.get("output_dir") is not None:
+                    writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(
+                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.tokens2text(token)
+
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
+                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+
+        return results, meta_data
+
+
+    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None):
+        """Build the hotword id-list from None, a local .txt file, a URL, or a
+        space-separated string; the list is terminated with [self.sos] ('<s>')."""
+        def load_seg_dict(seg_dict_file):
+            # Load word -> space-joined-subunit mapping (one entry per line).
+            seg_dict = {}
+            assert isinstance(seg_dict_file, str)
+            with open(seg_dict_file, "r", encoding="utf8") as f:
+                lines = f.readlines()
+            for line in lines:
+                s = line.strip().split()
+                key = s[0]
+                value = s[1:]
+                seg_dict[key] = " ".join(value)
+            return seg_dict
+
+        def seg_tokenize(txt, seg_dict):
+            # Map each word through seg_dict; pure CJK/digit words fall back to
+            # per-character lookup; unknown units become "<unk>".
+            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+            out_txt = ""
+            for word in txt:
+                word = word.lower()
+                if word in seg_dict:
+                    out_txt += seg_dict[word] + " "
+                else:
+                    if pattern.match(word):
+                        for char in word:
+                            if char in seg_dict:
+                                out_txt += seg_dict[char] + " "
+                            else:
+                                out_txt += "<unk>" + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            return out_txt.strip().split()
+
+        seg_dict = None
+        if self.frontend.cmvn_file is not None:
+            # A 'seg_dict' file shipped next to the model's cmvn file enables
+            # word-to-subunit segmentation of hotwords.
+            model_dir = os.path.dirname(self.frontend.cmvn_file)
+            seg_dict_file = os.path.join(model_dir, 'seg_dict')
+            if os.path.exists(seg_dict_file):
+                seg_dict = load_seg_dict(seg_dict_file)
+            else:
+                seg_dict = None
+        # for None
+        if hotword_list_or_file is None:
+            hotword_list = None
+        # for local txt inputs
+        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords from local txt...")
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
+                    hotword_str_list.append(hw)
+                    hotword_list.append(tokenizer.tokens2ids(hw_list))
+                hotword_list.append([self.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            # NOTE(review): TemporaryDirectory().name drops the object, so the
+            # directory may be cleaned up by GC unexpectedly -- consider
+            # tempfile.mkdtemp(); also open(...).write(...) leaves the file
+            # handle unclosed.
+            work_dir = tempfile.TemporaryDirectory().name
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            open(text_file_path, "wb").write(local_file.content)
+            hotword_list_or_file = text_file_path
+            hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hw_list = hw.split()
+                    if seg_dict is not None:
+                        hw_list = seg_tokenize(hw_list, seg_dict)
+                    hotword_str_list.append(hw)
+                    hotword_list.append(tokenizer.tokens2ids(hw_list))
+                hotword_list.append([self.sos])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                         .format(hotword_list_or_file, hotword_str_list))
+        # for text str input
+        elif not hotword_list_or_file.endswith('.txt'):
+            logging.info("Attempting to parse hotwords as str...")
+            hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                hw_list = hw.strip().split()
+                if seg_dict is not None:
+                    hw_list = seg_tokenize(hw_list, seg_dict)
+                hotword_list.append(tokenizer.tokens2ids(hw_list))
+            hotword_list.append([self.sos])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+        else:
+            hotword_list = None
+        return hotword_list
+
+
+class ParaformerOnline(Paraformer):
+    """
+    Streaming (online, chunk-based) Paraformer with SCAMA-style chunked
+    decoder cross-attention.
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+
+        super().__init__(*args, **kwargs)
+
+        self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
+
+
+        # SCAMA mask is produced by calc_predictor() and reused by
+        # cal_decoder_with_predictor(); i.e. state shared across the two calls.
+        self.scama_mask = None
+        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+
+
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+            kwargs:
+                decoding_ind: chunk-configuration index forwarded to the
+                    overlap-chunk selector at decode time -- TODO confirm
+        """
+        decoding_ind = kwargs.get("decoding_ind")
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # Encoder
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            # Randomly pick a chunk configuration during training; honour
+            # decoding_ind during decoding.
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
+        stats = dict()
+
+        # decoder: CTC branch
+
+        if self.ctc_weight > 0.0:
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                # CTC must see the de-chunked (contiguous) encoder output.
+                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                                    encoder_out_lens,
+                                                                                                    chunk_outs=None)
+            else:
+                encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+            )
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+        # decoder: Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        else:
+            loss = self.ctc_weight * loss_ctc + (
+                    1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum()
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode_chunk(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            cache: streaming state; cache["encoder"] is consumed/updated by the
+                encoder's forward_chunk
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
+    def _calc_att_predictor_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Attention + CIF predictor loss with SCAMA chunk masking for streaming.
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            # Append <eos> so the predictor learns the +1 target length.
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+                                                                              ys_pad,
+                                                                              encoder_out_mask,
+                                                                              ignore_id=self.ignore_id,
+                                                                              mask_chunk_predictor=mask_chunk_predictor,
+                                                                              target_label_length=ys_pad_lens,
+                                                                              )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            # Build the streaming (SCAMA) cross-attention mask from CIF frame
+            # alignments and the current chunk configuration.
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=ys_pad_lens,
+                is_training=self.training,
+            )
+        elif self.encoder.overlap_chunk_cls is not None:
+            # Non-chunk decoder attention: use the de-chunked encoder output.
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                        encoder_out_lens,
+                                                                                        chunk_outs=None)
+        # 0. sampler
+        decoder_out_1st = None
+        pre_loss_att = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            if self.use_1st_decoder_loss:
+                # Variant that also backpropagates through the first pass.
+                sematic_embeds, decoder_out_1st, pre_loss_att = \
+                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
+            else:
+                sematic_embeds, decoder_out_1st = \
+                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
+                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+        # Scheduled-sampling-style mixing: replace part of the acoustic embeds
+        # with ground-truth embeddings at wrongly-predicted positions.
+
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        with torch.no_grad():
+            # First decoding pass only guides the sampling; no gradients.
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                # Number of positions to replace grows with the error count.
+                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+                if target_num > 0:
+                    # NOTE(review): hard-coded .cuda() breaks CPU-only runs; the
+                    # contextual variant uses .to(pre_acoustic_embeds.device).
+                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        # Keep acoustic embeds where mask==1, ground-truth embeds elsewhere.
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+        """Run the CIF predictor in offline mode and cache the SCAMA decoder
+        mask in self.scama_mask for the following cal_decoder_with_predictor()."""
+
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+                                                                                           None,
+                                                                                           encoder_out_mask,
+                                                                                           ignore_id=self.ignore_id,
+                                                                                           mask_chunk_predictor=mask_chunk_predictor,
+                                                                                           target_label_length=None,
+                                                                                           )
+        # Extra frame accounts for the CIF tail fire when tail_threshold is set.
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=None,
+                is_training=self.training,
+            )
+        # Stash for cal_decoder_with_predictor(); stateful across calls.
+        self.scama_mask = scama_mask
+
+        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
+    def calc_predictor_chunk(self, encoder_out, cache=None):
+        # Streaming CIF predictor step; cache["encoder"] carries its state.
+        pre_acoustic_embeds, pre_token_length = \
+            self.predictor.forward_chunk(encoder_out, cache["encoder"])
+        return pre_acoustic_embeds, pre_token_length
+
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+        # Decode using the SCAMA mask cached by calc_predictor().
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
+    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+        # Streaming decoder step; cache["decoder"] carries its state.
+        decoder_outs = self.decoder.forward_chunk(
+            encoder_out, sematic_embeds, cache["decoder"]
+        )
+        decoder_out = decoder_outs
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
+
+    def generate(self,
+                 speech: torch.Tensor,
+                 speech_lengths: torch.Tensor,
+                 tokenizer=None,
+                 **kwargs,
+                 ):
+        """Offline-style decoding for the online model: encode the full
+        utterance, run the CIF predictor, then greedy/beam decode."""
+
+        is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
+        # NOTE(review): leftover debug print -- should be removed or demoted to logging.
+        print(is_use_ctc)
+        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(speech, speech_lengths, **kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        # Forward Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # predictor
+        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+                                                                        predictor_outs[2], predictor_outs[3]
+        pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            # CIF predicted no tokens anywhere in the batch.
+            return []
+        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+                                                       pre_token_length)
+        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        results = []
+        b, n, d = decoder_out.size()
+        for i in range(b):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            am_scores = decoder_out[i, :pre_token_length[i], :]
+            if self.beam_search is not None:
+                nbest_hyps = self.beam_search(
+                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+                    minlenratio=kwargs.get("minlenratio", 0.0)
+                )
+
+                nbest_hyps = nbest_hyps[: self.nbest]
+            else:
+                # Greedy decoding straight from the CIF-aligned decoder scores.
+                yseq = am_scores.argmax(dim=-1)
+                score = am_scores.max(dim=-1)[0]
+                score = torch.sum(score, dim=-1)
+                # pad with mask tokens to ensure compatibility with sos/eos tokens
+                yseq = torch.tensor(
+                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+                )
+                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+            for hyp in nbest_hyps:
+                assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank symbol id, which is assumed to be 0
+                # NOTE(review): ids 0 and 2 are hard-coded here (blank/eos?),
+                # unlike sibling variants that filter
+                # self.blank_id/self.sos/self.eos -- verify against the vocab.
+                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.tokens2text(token)
+
+                timestamp = []
+
+                results.append((text, token, timestamp))
+
+        return results
+
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
new file mode 100644
index 0000000..440b48e
--- /dev/null
+++ b/funasr/models/paraformer/search.py
@@ -0,0 +1,453 @@
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.modules.e2e_asr_common import end_detect
+from funasr.modules.scorers.scorer_interface import PartialScorerInterface
+from funasr.modules.scorers.scorer_interface import ScorerInterface
+
+class Hypothesis(NamedTuple):
+ """Hypothesis data type."""
+
+ yseq: torch.Tensor
+ score: Union[float, torch.Tensor] = 0
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
+ states: Dict[str, Any] = dict()
+
+ def asdict(self) -> dict:
+ """Convert data to JSON-friendly dict."""
+ return self._replace(
+ yseq=self.yseq.tolist(),
+ score=float(self.score),
+ scores={k: float(v) for k, v in self.scores.items()},
+ )._asdict()
+
+
+class BeamSearchPara(torch.nn.Module):
+ """Beam search implementation."""
+
+ def __init__(
+ self,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ beam_size: int,
+ vocab_size: int,
+ sos: int,
+ eos: int,
+ token_list: List[str] = None,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = None,
+ ):
+ """Initialize beam search.
+
+ Args:
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ token_list (list[str]): List of tokens for debug log
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ """
+ super().__init__()
+ # set scorers
+ self.weights = weights
+ self.scorers = dict()
+ self.full_scorers = dict()
+ self.part_scorers = dict()
+ # this module dict is required for recursive cast
+ # `self.to(device, dtype)` in `recog.py`
+ self.nn_dict = torch.nn.ModuleDict()
+ for k, v in scorers.items():
+ w = weights.get(k, 0)
+ if w == 0 or v is None:
+ continue
+ assert isinstance(
+ v, ScorerInterface
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
+ self.scorers[k] = v
+ if isinstance(v, PartialScorerInterface):
+ self.part_scorers[k] = v
+ else:
+ self.full_scorers[k] = v
+ if isinstance(v, torch.nn.Module):
+ self.nn_dict[k] = v
+
+ # set configurations
+ self.sos = sos
+ self.eos = eos
+ self.token_list = token_list
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
+ self.beam_size = beam_size
+ self.n_vocab = vocab_size
+ if (
+ pre_beam_score_key is not None
+ and pre_beam_score_key != "full"
+ and pre_beam_score_key not in self.full_scorers
+ ):
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+ self.pre_beam_score_key = pre_beam_score_key
+ self.do_pre_beam = (
+ self.pre_beam_score_key is not None
+ and self.pre_beam_size < self.n_vocab
+ and len(self.part_scorers) > 0
+ )
+
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+ """Get an initial hypothesis data.
+
+ Args:
+ x (torch.Tensor): The encoder output feature
+
+ Returns:
+ Hypothesis: The initial hypothesis.
+
+ """
+ init_states = dict()
+ init_scores = dict()
+ for k, d in self.scorers.items():
+ init_states[k] = d.init_state(x)
+ init_scores[k] = 0.0
+ return [
+ Hypothesis(
+ score=0.0,
+ scores=init_scores,
+ states=init_states,
+ yseq=torch.tensor([self.sos], device=x.device),
+ )
+ ]
+
+ @staticmethod
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+ """Append new token to prefix tokens.
+
+ Args:
+ xs (torch.Tensor): The prefix token
+ x (int): The new token to append
+
+ Returns:
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+ """
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+ return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+ return scores, states
+
+ def score_partial(
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.part_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.part_scorers`
+ and tensor score values of shape: `(len(ids),)`,
+ and state dict that has string keys
+ and state values of `self.part_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.part_scorers.items():
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+ return scores, states
+
+ def beam(
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute topk full token ids and partial token ids.
+
+ Args:
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+ Its shape is `(self.n_vocab,)`.
+ ids (torch.Tensor): The partial token ids to compute topk
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ The topk full token ids and partial token ids.
+ Their shapes are `(self.beam_size,)`
+
+ """
+ # no pre beam performed
+ if weighted_scores.size(0) == ids.size(0):
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ return top_ids, top_ids
+
+ # mask pruned in pre-beam not to select in topk
+ tmp = weighted_scores[ids]
+ weighted_scores[:] = -float("inf")
+ weighted_scores[ids] = tmp
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+ return top_ids, local_ids
+
+ @staticmethod
+ def merge_scores(
+ prev_scores: Dict[str, float],
+ next_full_scores: Dict[str, torch.Tensor],
+ full_idx: int,
+ next_part_scores: Dict[str, torch.Tensor],
+ part_idx: int,
+ ) -> Dict[str, torch.Tensor]:
+ """Merge scores for new hypothesis.
+
+ Args:
+ prev_scores (Dict[str, float]):
+ The previous hypothesis scores by `self.scorers`
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+ full_idx (int): The next token id for `next_full_scores`
+ next_part_scores (Dict[str, torch.Tensor]):
+ scores of partial tokens by `self.part_scorers`
+ part_idx (int): The new token id for `next_part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are scalar tensors by the scorers.
+
+ """
+ new_scores = dict()
+ for k, v in next_full_scores.items():
+ new_scores[k] = prev_scores[k] + v[full_idx]
+ for k, v in next_part_scores.items():
+ new_scores[k] = prev_scores[k] + v[part_idx]
+ return new_scores
+
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+ """Merge states for new hypothesis.
+
+ Args:
+ states: states of `self.full_scorers`
+ part_states: states of `self.part_scorers`
+ part_idx (int): The new token id for `part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are states of the scorers.
+
+ """
+ new_states = dict()
+ for k, v in states.items():
+ new_states[k] = v
+ for k, d in self.part_scorers.items():
+ new_states[k] = d.select_state(part_states[k], part_idx)
+ return new_states
+
+ def search(
+ self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
+ ) -> List[Hypothesis]:
+ """Search new tokens for running hypotheses and encoded speech x.
+
+ Args:
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
+ x (torch.Tensor): Encoded speech feature (T, D)
+
+ Returns:
+            List[Hypothesis]: Best sorted hypotheses
+
+ """
+ best_hyps = []
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
+ for hyp in running_hyps:
+ # scoring
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+ weighted_scores += am_score
+ scores, states = self.score_full(hyp, x)
+ for k in self.full_scorers:
+ weighted_scores += self.weights[k] * scores[k]
+ # partial scoring
+ if self.do_pre_beam:
+ pre_beam_scores = (
+ weighted_scores
+ if self.pre_beam_score_key == "full"
+ else scores[self.pre_beam_score_key]
+ )
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
+ for k in self.part_scorers:
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+ # add previous hyp score
+ weighted_scores += hyp.score
+
+ # update hyps
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+ # will be (2 x beam at most)
+ best_hyps.append(
+ Hypothesis(
+ score=weighted_scores[j],
+ yseq=self.append_token(hyp.yseq, j),
+ scores=self.merge_scores(
+ hyp.scores, scores, j, part_scores, part_j
+ ),
+ states=self.merge_states(states, part_states, part_j),
+ )
+ )
+
+ # sort and prune 2 x beam -> beam
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+ : min(len(best_hyps), self.beam_size)
+ ]
+ return best_hyps
+
+ def forward(
+ self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            am_scores (torch.Tensor): Acoustic scores per output step (maxlen, n_vocab)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), an end-detect function is used
+                to automatically find maximum hypothesis lengths.
+                If maxlenratio<0.0, its absolute value is a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+ Returns:
+ list[Hypothesis]: N-best decoding results
+
+ """
+ # set length bounds
+ maxlen = am_scores.shape[0]
+ logging.info("decoder input length: " + str(x.shape[0]))
+ logging.info("max output length: " + str(maxlen))
+
+ # main loop of prefix search
+ running_hyps = self.init_hyp(x)
+ ended_hyps = []
+ for i in range(maxlen):
+ logging.debug("position " + str(i))
+ best = self.search(running_hyps, x, am_scores[i])
+ # post process of one iteration
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+ # end detection
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+ logging.info(f"end detected at {i}")
+ break
+ if len(running_hyps) == 0:
+ logging.info("no hypothesis. Finish decoding.")
+ break
+ else:
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+ # check the number of hypotheses reaching to eos
+ if len(nbest_hyps) == 0:
+ logging.warning(
+ "there is no N-best results, perform recognition "
+ "again with smaller minlenratio."
+ )
+ return (
+ []
+ if minlenratio < 0.1
+            else self.forward(x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1))
+ )
+
+ # report the best result
+ best = nbest_hyps[0]
+ for k, v in best.scores.items():
+ logging.info(
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+ )
+ logging.info(f"total log probability: {best.score:.2f}")
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+ if self.token_list is not None:
+ logging.info(
+ "best hypo: "
+ + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]])
+ + "\n"
+ )
+ return nbest_hyps
+
+ def post_process(
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis],
+ ) -> List[Hypothesis]:
+ """Perform post-processing of beam search iterations.
+
+ Args:
+ i (int): The length of hypothesis tokens.
+ maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (float): The maximum length ratio in beam search.
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+ Returns:
+ List[Hypothesis]: The new running hypotheses.
+
+ """
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+ if self.token_list is not None:
+ logging.debug(
+ "best hypo: "
+ + "".join([self.token_list[x.item()] for x in running_hyps[0].yseq[1:]])
+ )
+ # add eos in the final loop to avoid that there are no ended hyps
+ if i == maxlen - 1:
+ logging.info("adding <eos> in the last position in the loop")
+ running_hyps = [
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
+ for h in running_hyps
+ ]
+
+ # add ended hypotheses to a final list, and removed them from current hypotheses
+ # (this will be a problem, number of hyps < beam)
+ remained_hyps = []
+ for hyp in running_hyps:
+ if hyp.yseq[-1] == self.eos:
+ # e.g., Word LM needs to add final <eos> score
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ s = d.final_score(hyp.states[k])
+ hyp.scores[k] += s
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+ ended_hyps.append(hyp)
+ else:
+ remained_hyps.append(hyp)
+ return remained_hyps
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d2fc3f0..bbfd173 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -66,7 +66,9 @@
return text_ints
def decode(self, text_ints):
- return self.ids2tokens(text_ints)
+ token = self.ids2tokens(text_ints)
+ text = self.tokens2text(token)
+ return text
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
diff --git a/funasr/utils/download_from_hub.py b/funasr/utils/download_from_hub.py
index d6e4ab4..d6b79d3 100644
--- a/funasr/utils/download_from_hub.py
+++ b/funasr/utils/download_from_hub.py
@@ -11,10 +11,10 @@
return kwargs
def download_fr_ms(**kwargs):
- model_or_path = kwargs.get("model_pretrain")
- model_revision = kwargs.get("model_pretrain_revision")
+ model_or_path = kwargs.get("model")
+ model_revision = kwargs.get("model_revision")
if not os.path.exists(model_or_path):
- model_or_path = get_or_download_model_dir(model_or_path, model_revision, third_party="funasr")
+ model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
config = os.path.join(model_or_path, "config.yaml")
assert os.path.exists(config), "{} is not exist!".format(config)
@@ -23,25 +23,29 @@
init_param = os.path.join(model_or_path, "model.pb")
kwargs["init_param"] = init_param
kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ kwargs["model"] = cfg["model"]
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
return kwargs
def get_or_download_model_dir(
model,
model_revision=None,
- third_party=None):
+ is_training=False,
+ ):
""" Get local model directory or download model if necessary.
Args:
model (str): model id or path to local model directory.
model_revision (str, optional): model version number.
- third_party (str, optional): in which third party library
- this function is called.
+        is_training (bool, optional): whether called from training; selects the ModelScope Invoke key.
"""
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Invoke, ThirdParty
+
+ key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
if os.path.exists(model):
model_cache_dir = model if os.path.isdir(
@@ -49,15 +53,15 @@
check_local_model_is_latest(
model_cache_dir,
user_agent={
- Invoke.KEY: Invoke.LOCAL_TRAINER,
- ThirdParty.KEY: third_party
+ Invoke.KEY: key,
+ ThirdParty.KEY: "funasr"
})
else:
model_cache_dir = snapshot_download(
model,
revision=model_revision,
user_agent={
- Invoke.KEY: Invoke.TRAINER,
- ThirdParty.KEY: third_party
+ Invoke.KEY: key,
+ ThirdParty.KEY: "funasr"
})
return model_cache_dir
\ No newline at end of file
--
Gitblit v1.9.1