| | |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | |
| | | punc_choices = ClassChoices( |
| | | "punctuation", |
| | | classes=dict( |
| | | target_delay=TargetDelayTransformer, |
| | | ), |
| | | type_check=AbsPunctuation, |
| | | default="TargetDelayTransformer", |
| | | classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer), |
| | | default="target_delay", |
| | | ) |
| | | |
| | | |
| | |
| | | @classmethod |
| | | def add_task_arguments(cls, parser: argparse.ArgumentParser): |
| | | # NOTE(kamo): Use '_' instead of '-' to avoid confusion |
| | | assert check_argument_types() |
| | | group = parser.add_argument_group(description="Task related") |
| | | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | required = parser.get_default("required") |
| | | #import pdb;pdb.set_trace() |
| | | #required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetPunctuationModel), |
| | | default=get_default_kwargs(PunctuationModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(group) |
| | | |
| | | assert check_return_type(parser) |
| | | return parser |
| | | |
| | | @classmethod |
| | |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | assert check_argument_types() |
| | | return CommonCollateFn(int_pad_value=0) |
| | | |
| | | @classmethod |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | token_types = [args.token_type, args.token_type] |
| | | token_lists = [args.token_list, args.punc_list] |
| | | bpemodels = [args.bpemodel, args.bpemodel] |
| | | text_names = ["text", "punc"] |
| | | if args.use_preprocessor: |
| | | retval = MutliTokenizerCommonPreprocessor( |
| | | retval = PuncTrainTokenizerCommonPreprocessor( |
| | | train=train, |
| | | token_type=token_types, |
| | | token_list=token_lists, |
| | |
| | | ) |
| | | else: |
| | | retval = None |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | |
| | | def optional_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | retval = ("vad",) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel: |
| | | assert check_argument_types() |
| | | def build_model(cls, args: argparse.Namespace) -> PunctuationModel: |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | |
| | | args.token_list = token_list.copy() |
| | | if isinstance(args.punc_list, str): |
| | | with open(args.punc_list, encoding="utf-8") as f2: |
| | | punc_list = [line.rstrip() for line in f2] |
| | | pairs = [line.rstrip().split(":") for line in f2] |
| | | punc_list = [pair[0] for pair in pairs] |
| | | punc_weight_list = [float(pair[1]) for pair in pairs] |
| | | args.punc_list = punc_list.copy() |
| | | elif isinstance(args.punc_list, list): |
| | | # This is in the inference code path. |
| | | punc_list = args.punc_list.copy() |
| | | punc_weight_list = [1] * len(punc_list) |
| | | if isinstance(args.token_list, (tuple, list)): |
| | | token_list = args.token_list.copy() |
| | | else: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf) |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | assert check_return_type(model) |
| | | return model |