From fac8dbd8210406b95d8b7d43e5ca540ac5cb1995 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 20 Apr 2023 17:36:49 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_model.py | 3
funasr/bin/train.py | 59 +++++++++++
funasr/build_utils/build_args.py | 134 ++------------------------
funasr/build_utils/build_punc_model.py | 67 +++++++++++++
4 files changed, 135 insertions(+), 128 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index c173167..e861199 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -6,18 +6,20 @@
import torch
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
from funasr.build_utils.build_args import build_args
from funasr.build_utils.build_dataloader import build_dataloader
from funasr.build_utils.build_distributed import build_distributed
from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.model_summary import model_summary
+from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
@@ -281,6 +283,55 @@
help="Apply preprocessing to data or not",
)
+ # most task related
+ parser.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+ parser.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ parser.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word"],
+        help="The text will be tokenized in the specified level token",
+ )
+ parser.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+        help="The model file of sentencepiece",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+
# pai related
parser.add_argument(
"--use_pai",
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 91f2810..fc737ba 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -17,12 +17,6 @@
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(parser)
parser.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="A text mapping int-id to token",
- )
- parser.add_argument(
"--split_with_space",
type=str2bool,
default=True,
@@ -33,20 +27,6 @@
type=str,
default=None,
help="seg_dict_file for text processing",
- )
- parser.add_argument(
- "--init",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The initialization method",
- choices=[
- "chainer",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- None,
- ],
)
parser.add_argument(
"--input_size",
@@ -61,31 +41,12 @@
help="The keyword arguments for CTC class.",
)
parser.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word", "phn"],
- help="The text will be tokenized " "in the specified level token",
- )
- parser.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model file of sentencepiece",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Apply text cleaning",
- )
- parser.add_argument(
"--cmvn_file",
type=str_or_none,
default=None,
help="The file path of noise scp file.",
)
+
elif args.task_name == "pretrain":
from funasr.build_utils.build_pretrain_model import class_choices_list
for class_choices in class_choices_list:
@@ -93,101 +54,26 @@
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(parser)
parser.add_argument(
- "--init",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The initialization method",
- choices=[
- "chainer",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- None,
- ],
- )
- parser.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
- parser.add_argument(
- "--feats_type",
- type=str,
- default='fbank',
- help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
- )
- parser.add_argument(
- "--noise_db_range",
- type=str,
- default="13_15",
- help="The range of noise decibel level.",
- )
- parser.add_argument(
- "--pred_masked_weight",
- type=float,
- default=1.0,
- help="weight for predictive loss for masked frames",
- )
- parser.add_argument(
- "--pred_nomask_weight",
- type=float,
- default=0.0,
- help="weight for predictive loss for unmasked frames",
- )
- parser.add_argument(
- "--loss_weights",
- type=float,
- default=0.0,
- help="weights for additional loss terms (not first one)",
- )
+
elif args.task_name == "lm":
from funasr.build_utils.build_lm_model import class_choices_list
for class_choices in class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(parser)
- parser.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="A text mapping int-id to token",
- )
- parser.add_argument(
- "--init",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- help="The initialization method",
- choices=[
- "chainer",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- None,
- ],
- )
- parser.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word"],
- help="",
- )
- parser.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model file fo sentencepiece",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Apply text cleaning",
- )
+
+ elif args.task_name == "punc":
+ from funasr.build_utils.build_punc_model import class_choices_list
+ for class_choices in class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(parser)
+
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 8222631..b1d1230 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,6 +1,7 @@
from funasr.build_utils.build_asr_model import build_asr_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_punc_model import build_punc_model
def build_model(args):
@@ -10,6 +11,8 @@
model = build_pretrain_model(args)
elif args.task_name == "lm":
model = build_lm_model(args)
+ elif args.task_name == "punc":
+ model = build_punc_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
new file mode 100644
index 0000000..d098ffc
--- /dev/null
+++ b/funasr/build_utils/build_punc_model.py
@@ -0,0 +1,67 @@
+import logging
+
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.train.class_choices import ClassChoices
+
+punc_choices = ClassChoices(
+ "punctuation",
+ classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
+ type_check=AbsPunctuation,
+ default="target_delay",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ punc=PunctuationModel,
+ ),
+ default="punc",
+)
+class_choices_list = [
+ # --punc and --punc_conf
+ punc_choices,
+ # --model and --model_conf
+ model_choices
+]
+
+
+def build_punc_model(args):
+ # token_list and punc list
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = token_list.copy()
+ if isinstance(args.punc_list, str):
+ with open(args.punc_list, encoding="utf-8") as f2:
+ pairs = [line.rstrip().split(":") for line in f2]
+ punc_list = [pair[0] for pair in pairs]
+ punc_weight_list = [float(pair[1]) for pair in pairs]
+ args.punc_list = punc_list.copy()
+ elif isinstance(args.punc_list, list):
+ punc_list = args.punc_list.copy()
+ punc_weight_list = [1] * len(punc_list)
+ if isinstance(args.token_list, (tuple, list)):
+ token_list = args.token_list.copy()
+ else:
+        raise RuntimeError("token_list must be str or list")
+
+ vocab_size = len(token_list)
+ punc_size = len(punc_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # punc
+ punc_class = punc_choices.get_class(args.punctuation)
+ punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
+
+ if "punc_weight" in args.model_conf:
+ args.model_conf.pop("punc_weight")
+ model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
--
Gitblit v1.9.1