| | |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | | from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | |
| | | "model", |
| | | classes=dict( |
| | | sond=DiarSondModel, |
| | | eend_ola=DiarEENDOLAModel, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | default="sond", |
| | |
| | | if ".bin" in model_name: |
| | | model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb')) |
| | | else: |
| | | model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name)) |
| | | model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name)) |
| | | if os.path.exists(model_name_pth): |
| | | logging.info("model_file is load from pth: {}".format(model_name_pth)) |
| | | model_dict = torch.load(model_name_pth, map_location=device) |
| | |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | if args.use_preprocessor: |
| | | retval = CommonPreprocessor( |
| | | train=train, |
| | | token_type=args.token_type, |
| | | token_list=args.token_list, |
| | | bpemodel=None, |
| | | non_linguistic_symbols=None, |
| | | text_cleaner=None, |
| | | g2p_type=None, |
| | | split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | # NOTE(kamo): Check attribute existence for backward compatibility |
| | | rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, |
| | | rir_apply_prob=args.rir_apply_prob |
| | | if hasattr(args, "rir_apply_prob") |
| | | else 1.0, |
| | | noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, |
| | | noise_apply_prob=args.noise_apply_prob |
| | | if hasattr(args, "noise_apply_prob") |
| | | else 1.0, |
| | | noise_db_range=args.noise_db_range |
| | | if hasattr(args, "noise_db_range") |
| | | else "13_15", |
| | | speech_volume_normalize=args.speech_volume_normalize |
| | | if hasattr(args, "rir_scp") |
| | | else None, |
| | | ) |
| | | else: |
| | | retval = None |
| | | assert check_return_type(retval) |
| | | return retval |
| | | # if args.use_preprocessor: |
| | | # retval = CommonPreprocessor( |
| | | # train=train, |
| | | # token_type=args.token_type, |
| | | # token_list=args.token_list, |
| | | # bpemodel=None, |
| | | # non_linguistic_symbols=None, |
| | | # text_cleaner=None, |
| | | # g2p_type=None, |
| | | # split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | | # seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | # # NOTE(kamo): Check attribute existence for backward compatibility |
| | | # rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, |
| | | # rir_apply_prob=args.rir_apply_prob |
| | | # if hasattr(args, "rir_apply_prob") |
| | | # else 1.0, |
| | | # noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, |
| | | # noise_apply_prob=args.noise_apply_prob |
| | | # if hasattr(args, "noise_apply_prob") |
| | | # else 1.0, |
| | | # noise_db_range=args.noise_db_range |
| | | # if hasattr(args, "noise_db_range") |
| | | # else "13_15", |
| | | # speech_volume_normalize=args.speech_volume_normalize |
| | | # if hasattr(args, "rir_scp") |
| | | # else None, |
| | | # ) |
| | | # else: |
| | | # retval = None |
| | | # assert check_return_type(retval) |
| | | return None |
| | | |
@classmethod
def required_data_names(
    cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    """Return the names of the data entries that must always be provided.

    Args:
        train: Accepted for interface compatibility; the result does not
            depend on it.
        inference: Accepted for interface compatibility; the result does
            not depend on it.

    Returns:
        The one-element tuple ``("speech",)``.
    """
    # The earlier body assigned ``retval`` twice in each branch, so the first
    # value (``("speech", "profile", "binary_labels")`` when not in inference
    # mode, and the bare string ``("speech")`` in inference mode) was always
    # overwritten by ``("speech", )``. Only the effective value is kept; the
    # trailing comma matters, since ``("speech")`` is a str, not a tuple.
    return ("speech",)
| | | |
| | | @classmethod |
| | |
| | | |
| | | # 2. Encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | encoder = encoder_class(**args.encoder_conf) |
| | | |
| | | # 3. EncoderDecoderAttractor |
| | | encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) |