| | |
| | | |
| | | import torch |
| | | |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.build_utils.build_args import build_args |
| | | from funasr.build_utils.build_dataloader import build_dataloader |
| | | from funasr.build_utils.build_distributed import build_distributed |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.build_utils.build_optimizer import build_optimizer |
| | | from funasr.build_utils.build_scheduler import build_scheduler |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.prepare_data import prepare_data |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | |
| | | |
| | |
| | | help="Apply preprocessing to data or not", |
| | | ) |
| | | |
| | | # most task related |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="A text mapping int-id to token", |
| | | ) |
| | | parser.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="bpe", |
| | | choices=["bpe", "char", "word"], |
| | | help="", |
| | | ) |
| | | parser.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model file fo sentencepiece", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese"], |
| | | default=None, |
| | | help="Apply text cleaning", |
| | | ) |
| | | parser.add_argument( |
| | | "--g2p", |
| | | type=str_or_none, |
| | | choices=g2p_choices, |
| | | default=None, |
| | | help="Specify g2p method if --token_type=phn", |
| | | ) |
| | | |
| | | # pai related |
| | | parser.add_argument( |
| | | "--use_pai", |
| | |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(parser) |
| | | parser.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="A text mapping int-id to token", |
| | | ) |
| | | parser.add_argument( |
| | | "--split_with_space", |
| | | type=str2bool, |
| | | default=True, |
| | |
| | | type=str, |
| | | default=None, |
| | | help="seg_dict_file for text processing", |
| | | ) |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--input_size", |
| | |
| | | help="The keyword arguments for CTC class.", |
| | | ) |
| | | parser.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="bpe", |
| | | choices=["bpe", "char", "word", "phn"], |
| | | help="The text will be tokenized " "in the specified level token", |
| | | ) |
| | | parser.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model file of sentencepiece", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese"], |
| | | default=None, |
| | | help="Apply text cleaning", |
| | | ) |
| | | parser.add_argument( |
| | | "--cmvn_file", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The file path of noise scp file.", |
| | | ) |
| | | |
| | | elif args.task_name == "pretrain": |
| | | from funasr.build_utils.build_pretrain_model import class_choices_list |
| | | for class_choices in class_choices_list: |
| | |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(parser) |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--input_size", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="The number of input dimension of the feature", |
| | | ) |
| | | parser.add_argument( |
| | | "--feats_type", |
| | | type=str, |
| | | default='fbank', |
| | | help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_db_range", |
| | | type=str, |
| | | default="13_15", |
| | | help="The range of noise decibel level.", |
| | | ) |
| | | parser.add_argument( |
| | | "--pred_masked_weight", |
| | | type=float, |
| | | default=1.0, |
| | | help="weight for predictive loss for masked frames", |
| | | ) |
| | | parser.add_argument( |
| | | "--pred_nomask_weight", |
| | | type=float, |
| | | default=0.0, |
| | | help="weight for predictive loss for unmasked frames", |
| | | ) |
| | | parser.add_argument( |
| | | "--loss_weights", |
| | | type=float, |
| | | default=0.0, |
| | | help="weights for additional loss terms (not first one)", |
| | | ) |
| | | |
| | | elif args.task_name == "lm": |
| | | from funasr.build_utils.build_lm_model import class_choices_list |
| | | for class_choices in class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(parser) |
| | | parser.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="A text mapping int-id to token", |
| | | ) |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="bpe", |
| | | choices=["bpe", "char", "word"], |
| | | help="", |
| | | ) |
| | | parser.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model file fo sentencepiece", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese"], |
| | | default=None, |
| | | help="Apply text cleaning", |
| | | ) |
| | | |
| | | elif args.task_name == "punc": |
| | | from funasr.build_utils.build_punc_model import class_choices_list |
| | | for class_choices in class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(parser) |
| | | |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| | |
| | | from funasr.build_utils.build_asr_model import build_asr_model |
| | | from funasr.build_utils.build_pretrain_model import build_pretrain_model |
| | | from funasr.build_utils.build_lm_model import build_lm_model |
| | | from funasr.build_utils.build_punc_model import build_punc_model |
| | | |
| | | |
| | | def build_model(args): |
| | |
| | | model = build_pretrain_model(args) |
| | | elif args.task_name == "lm": |
| | | model = build_lm_model(args) |
| | | elif args.task_name == "punc": |
| | | model = build_punc_model(args) |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| New file |
| | |
| | | import logging |
| | | |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | |
# Registry of punctuation network architectures, selected on the command line
# via --punctuation (and configured via --punctuation_conf).
punc_choices = ClassChoices(
    "punctuation",
    classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
    type_check=AbsPunctuation,
    default="target_delay",
)
# Registry of top-level model wrappers, selected via --model.
model_choices = ClassChoices(
    "model",
    classes=dict(
        punc=PunctuationModel,
    ),
    default="punc",
)
# All class-choice groups for the "punc" task; the argument builder iterates
# this list to append the corresponding CLI options to the parser.
class_choices_list = [
    # --punc and --punc_conf
    punc_choices,
    # --model and --model_conf
    model_choices
]
| | | |
| | | |
def build_punc_model(args):
    """Build a punctuation-restoration model from parsed CLI arguments.

    ``args.token_list`` and ``args.punc_list`` may each be either a file path
    (one entry per line; punc entries formatted as ``"punc:weight"``) or an
    already-materialized list. The punctuation network chosen by
    ``args.punctuation`` is instantiated and wrapped in a ``PunctuationModel``.

    Args:
        args: argparse.Namespace providing at least ``token_list``,
            ``punc_list``, ``punctuation``, ``punctuation_conf``,
            ``model_conf`` and ``init``.

    Returns:
        The constructed (and optionally weight-initialized) ``PunctuationModel``.

    Raises:
        RuntimeError: if ``token_list`` or ``punc_list`` has an unsupported type.
    """
    # token_list and punc list
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Store the resolved list back on args so it can be dumped to the
        # experiment config instead of the original file path.
        args.token_list = token_list.copy()
    if isinstance(args.punc_list, str):
        with open(args.punc_list, encoding="utf-8") as f2:
            pairs = [line.rstrip().split(":") for line in f2]
            punc_list = [pair[0] for pair in pairs]
            punc_weight_list = [float(pair[1]) for pair in pairs]
            args.punc_list = punc_list.copy()
    elif isinstance(args.punc_list, list):
        punc_list = args.punc_list.copy()
        # No per-class weights supplied: weight every punctuation equally.
        # Floats for consistency with the file-loaded branch above.
        punc_weight_list = [1.0] * len(punc_list)
    else:
        # Fail fast with a clear message instead of a NameError on
        # `punc_list` further down.
        raise RuntimeError("punc_list must be str or list")
    if isinstance(args.token_list, (tuple, list)):
        token_list = args.token_list.copy()
    else:
        # Message matches the (tuple, list) check above; the original said
        # "str or dict", which was wrong.
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    punc_size = len(punc_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # punctuation network (selected via --punctuation)
    punc_class = punc_choices.get_class(args.punctuation)
    punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)

    # punc_weight is passed explicitly below; drop any duplicate from
    # model_conf to avoid "multiple values for keyword argument".
    if "punc_weight" in args.model_conf:
        args.model_conf.pop("punc_weight")
    model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)

    # initialize weights, e.g. "xavier_uniform" (skipped when --init is None)
    if args.init is not None:
        initialize(model, args.init)

    return model