speech_asr
2023-04-20 eac9f111b502e4581b14dc718731bf7dc1c7d5f6
update
4个文件已修改
1个文件已添加
168 ■■■■ 已修改文件
funasr/utils/build_args.py 57 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_asr_model.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_lm_model.py 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_model.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_pretrain_model.py 74 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_args.py
@@ -79,7 +79,62 @@
            default=None,
            help="The file path of noise scp file.",
        )
    elif args.task_name == "pretrain":
        from funasr.utils.build_pretrain_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        parser.add_argument(
            "--feats_type",
            type=str,
            default='fbank',
            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        parser.add_argument(
            "--pred_masked_weight",
            type=float,
            default=1.0,
            help="weight for predictive loss for masked frames",
        )
        parser.add_argument(
            "--pred_nomask_weight",
            type=float,
            default=0.0,
            help="weight for predictive loss for unmasked frames",
        )
        parser.add_argument(
            "--loss_weights",
            type=float,
            default=0.0,
            help="weights for additional loss terms (not first one)",
        )
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/utils/build_asr_model.py
@@ -345,6 +345,7 @@
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
    # initialize
    if args.init is not None:
        initialize(model, args.init)
funasr/utils/build_lm_model.py
New file
@@ -0,0 +1,34 @@
import logging

from funasr.lm.abs_model import AbsLM
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Registry of selectable LM architectures.  ClassChoices wires each name to a
# class and (via add_arguments) exposes --lm / --lm_conf on the CLI parser.
lm_choices = ClassChoices(
    "lm",
    classes=dict(
        seq_rnn=SequentialRNNLM,
        transformer=TransformerLM,
    ),
    type_check=AbsLM,  # every registered class must subclass AbsLM
    default="seq_rnn",
)
# Consumed by funasr/utils/build_args.py, which iterates this list and calls
# class_choices.add_arguments(parser) for each entry.
class_choices_list = [
    # --lm and --lm_conf
    lm_choices
]
def build_lm_model(args):
    """Build a language model from parsed command-line ``args``.

    Args:
        args: argparse.Namespace carrying at least ``token_list`` (path to a
            vocabulary file or None), ``lm`` (key into ``lm_choices``) and
            ``lm_conf`` (kwargs dict for the chosen LM class).

    Returns:
        An instantiated LM (subclass of ``AbsLM``).

    Raises:
        NameError (via ``lm_choices.get_class``) for an unknown ``args.lm``.
    """
    # token_list: load the vocabulary so the LM knows its output dimension.
    if args.token_list is not None:
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]
        # Re-store as a plain list so the namespace is serializable downstream.
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info("Vocabulary size: %d", vocab_size)
    else:
        vocab_size = None

    # Original code returned an undefined ``model`` (NameError); actually
    # instantiate the selected LM class here.
    # NOTE(review): vocab_size as first kwarg follows the SequentialRNNLM /
    # TransformerLM constructor convention — confirm against those classes.
    lm_class = lm_choices.get_class(args.lm)
    model = lm_class(vocab_size=vocab_size, **args.lm_conf)

    return model


# Backward-compatible alias: this function was originally (mis)named after the
# pretrain builder it was copied from; build_model.py expects build_lm_model.
build_pretrain_model = build_lm_model
funasr/utils/build_model.py
@@ -7,6 +7,8 @@
        model = build_asr_model(args)
    elif args.task_name == "pretrain":
        model = build_pretrain_model(args)
    elif args.task_name == "lm":
        model = build_lm_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
funasr/utils/build_pretrain_model.py
@@ -57,39 +57,39 @@
def build_pretrain_model(args):
    # frontend
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    # data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None
    # normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None
    # encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(
        input_size=input_size,
        **args.encoder_conf,
    )
    if args.model_name == "data2vec":
        # frontend
        if args.input_size is None:
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None
        # normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None
        # encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(
            input_size=input_size,
            **args.encoder_conf,
        )
        model_class = model_choices.get_class("data2vec")
        model = model_class(
            frontend=frontend,
@@ -97,9 +97,11 @@
            normalize=normalize,
            encoder=encoder,
        )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
        # 7. Initialize
        if args.init is not None:
            initialize(model, args.init)
    # initialize
    if args.init is not None:
        initialize(model, args.init)
        return model
    return model