speech_asr
2023-04-20 fac8dbd8210406b95d8b7d43e5ca540ac5cb1995
update
3个文件已修改
1个文件已添加
263 ■■■■ 已修改文件
funasr/bin/train.py 59 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_args.py 134 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_model.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_punc_model.py 67 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py
@@ -6,18 +6,20 @@
import torch
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.build_utils.build_args import build_args
from funasr.build_utils.build_dataloader import build_dataloader
from funasr.build_utils.build_distributed import build_distributed
from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
@@ -281,6 +283,55 @@
        help="Apply preprocessing to data or not",
    )
    # most task related
    parser.add_argument(
        "--init",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The initialization method",
        choices=[
            "chainer",
            "xavier_uniform",
            "xavier_normal",
            "kaiming_uniform",
            "kaiming_normal",
            None,
        ],
    )
    parser.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    parser.add_argument(
        "--token_type",
        type=str,
        default="bpe",
        choices=["bpe", "char", "word"],
        help="",
    )
    parser.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model file fo sentencepiece",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )
    # pai related
    parser.add_argument(
        "--use_pai",
funasr/build_utils/build_args.py
@@ -17,12 +17,6 @@
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        parser.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
@@ -33,20 +27,6 @@
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--input_size",
@@ -61,31 +41,12 @@
            help="The keyword arguments for CTC class.",
        )
        parser.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        parser.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
    elif args.task_name == "pretrain":
        from funasr.build_utils.build_pretrain_model import class_choices_list
        for class_choices in class_choices_list:
@@ -93,101 +54,26 @@
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        parser.add_argument(
            "--feats_type",
            type=str,
            default='fbank',
            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        parser.add_argument(
            "--pred_masked_weight",
            type=float,
            default=1.0,
            help="weight for predictive loss for masked frames",
        )
        parser.add_argument(
            "--pred_nomask_weight",
            type=float,
            default=0.0,
            help="weight for predictive loss for unmasked frames",
        )
        parser.add_argument(
            "--loss_weights",
            type=float,
            default=0.0,
            help="weights for additional loss terms (not first one)",
        )
    elif args.task_name == "lm":
        from funasr.build_utils.build_lm_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word"],
            help="",
        )
        parser.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file fo sentencepiece",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
    elif args.task_name == "punc":
        from funasr.build_utils.build_punc_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_model.py
@@ -1,6 +1,7 @@
from funasr.build_utils.build_asr_model import build_asr_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_punc_model import build_punc_model
def build_model(args):
@@ -10,6 +11,8 @@
        model = build_pretrain_model(args)
    elif args.task_name == "lm":
        model = build_lm_model(args)
    elif args.task_name == "punc":
        model = build_punc_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/build_utils/build_punc_model.py
New file
@@ -0,0 +1,67 @@
import logging
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.train.class_choices import ClassChoices
# Registry of selectable punctuation network architectures; registered on the
# CLI as --punctuation / --punctuation_conf via ClassChoices.add_arguments.
punc_choices = ClassChoices(
    "punctuation",
    classes={
        "target_delay": TargetDelayTransformer,
        "vad_realtime": VadRealtimeTransformer,
    },
    type_check=AbsPunctuation,
    default="target_delay",
)

# Registry of top-level model wrappers (--model / --model_conf); currently the
# single PunctuationModel wrapper.
model_choices = ClassChoices(
    "model",
    classes={"punc": PunctuationModel},
    default="punc",
)

# Consumed by build_args.py to auto-register the --<name>/--<name>_conf flags.
class_choices_list = [punc_choices, model_choices]
def build_punc_model(args):
    """Build a punctuation model from parsed command-line/config arguments.

    Resolves ``args.token_list`` and ``args.punc_list`` (each either a file
    path or an in-memory list), instantiates the punctuation network selected
    by ``args.punctuation``, wraps it in ``PunctuationModel`` and optionally
    applies weight initialization.

    Args:
        args: Namespace providing at least ``token_list``, ``punc_list``,
            ``punctuation``, ``punctuation_conf``, ``model_conf`` and ``init``.
            ``token_list`` / ``punc_list`` are rewritten in place to plain
            lists when they were given as file paths.

    Returns:
        A ``PunctuationModel`` instance.

    Raises:
        RuntimeError: if ``token_list`` or ``punc_list`` is neither a path
            string nor a list.
    """
    # token_list: path to a vocabulary file (one token per line) or a list.
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Store the expanded list back so later config dumps see real tokens.
        args.token_list = token_list.copy()

    # punc_list: path to "punc:weight" lines, or a list (weights default to 1).
    if isinstance(args.punc_list, str):
        with open(args.punc_list, encoding="utf-8") as f2:
            pairs = [line.rstrip().split(":") for line in f2]
        punc_list = [pair[0] for pair in pairs]
        punc_weight_list = [float(pair[1]) for pair in pairs]
        args.punc_list = punc_list.copy()
    elif isinstance(args.punc_list, list):
        punc_list = args.punc_list.copy()
        punc_weight_list = [1] * len(punc_list)
    else:
        # Previously fell through and crashed later with a NameError on
        # punc_list; fail fast with a clear message instead.
        raise RuntimeError("punc_list must be str or list")

    if isinstance(args.token_list, (tuple, list)):
        token_list = args.token_list.copy()
    else:
        # Message fixed: the accepted types are str or list, not dict.
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    punc_size = len(punc_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # Instantiate the punctuation network chosen via --punctuation.
    punc_class = punc_choices.get_class(args.punctuation)
    punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)

    # punc_weight comes from punc_weight_list; drop any conflicting conf key.
    if "punc_weight" in args.model_conf:
        args.model_conf.pop("punc_weight")
    model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)

    # Optional weight initialization (--init), e.g. xavier_uniform.
    if args.init is not None:
        initialize(model, args.init)
    return model