| | |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | if args.mode == "uniasr": |
| | | from funasr.tasks.asr import ASRTaskUniASR as ASRTask |
| | | if args.mode == "rnnt": |
| | | from funasr.tasks.asr import ASRTransducerTask as ASRTask |
| | | |
| | | ASRTask.main(args=args, cmd=cmd) |
| | | |
| | |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.e2e_asr import ASRModel |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.mfcca_encoder import MFCCAEncoder |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | mfcca_enc=MFCCAEncoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | ), |
| | | default="rnn", |
| | | ) |
| | |
| | | default="stride_conv1d", |
| | | optional=True, |
| | | ) |
| | | rnnt_decoder_choices = ClassChoices( |
| | | name="rnnt_decoder", |
| | | classes=dict( |
| | | rnnt=RNNTDecoder, |
| | | ), |
| | | default="rnnt", |
| | | optional=True, |
| | | ) |
| | | joint_network_choices = ClassChoices( |
| | | name="joint_network", |
| | | classes=dict( |
| | | joint_network=JointNetwork, |
| | | ), |
| | | default="joint_network", |
| | | optional=True, |
| | | ) |
| | | |
| | | class_choices_list = [ |
| | | # --frontend and --frontend_conf |
| | | frontend_choices, |
| | |
| | | predictor_choices2, |
| | | # --stride_conv and --stride_conv_conf |
| | | stride_conv_choices, |
| | | # --rnnt_decoder and --rnnt_decoder_conf |
| | | rnnt_decoder_choices, |
| | | # --joint_network and --joint_network_conf |
| | | joint_network_choices, |
| | | ] |
| | | |
| | | |
| | |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "rnnt": |
| | | # 5. Decoder |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder) |
| | | decoder = rnnt_decoder_class( |
| | | vocab_size, |
| | | **args.rnnt_decoder_conf, |
| | | ) |
| | | decoder_output_size = decoder.output_size |
| | | |
| | | if getattr(args, "decoder", None) is not None: |
| | | att_decoder_class = decoder_choices.get_class(args.decoder) |
| | | |
| | | att_decoder = att_decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | | **args.decoder_conf, |
| | | ) |
| | | else: |
| | | att_decoder = None |
| | | # 6. Joint Network |
| | | joint_network = JointNetwork( |
| | | vocab_size, |
| | | encoder_output_size, |
| | | decoder_output_size, |
| | | **args.joint_network_conf, |
| | | ) |
| | | |
| | | # 7. Build model |
| | | if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: |
| | | model = UnifiedTransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | else: |
| | | model = TransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Not supported model: {}".format(args.model)) |
| | | |
| | |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | return model |
| | | return model |
| | |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | mask = make_source_mask(x_len).to(x.device) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | |
| | | predictor_choices2, |
| | | # --stride_conv and --stride_conv_conf |
| | | stride_conv_choices, |
| | | # --rnnt_decoder and --rnnt_decoder_conf |
| | | rnnt_decoder_choices, |
| | | ] |
| | | |
| | | # If you need to modify train() or eval() procedures, change Trainer class here |
| | |
| | | return retval |
| | | |
| | | |
| | | class ASRTransducerTask(AbsTask): |
| | | class ASRTransducerTask(ASRTask): |
| | | """ASR Transducer Task definition.""" |
| | | |
| | | num_optimizers: int = 1 |
| | |
| | | normalize_choices, |
| | | encoder_choices, |
| | | rnnt_decoder_choices, |
| | | joint_network_choices, |
| | | ] |
| | | |
| | | trainer = Trainer |
| | | |
| | | @classmethod |
| | | def add_task_arguments(cls, parser: argparse.ArgumentParser): |
| | | """Add Transducer task arguments. |
| | | Args: |
| | | cls: ASRTransducerTask object. |
| | | parser: Transducer arguments parser. |
| | | """ |
| | | group = parser.add_argument_group(description="Task related.") |
| | | |
| | | # required = parser.get_default("required") |
| | | # required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="Integer-string mapper for tokens.", |
| | | ) |
| | | group.add_argument( |
| | | "--split_with_space", |
| | | type=str2bool, |
| | | default=True, |
| | | help="whether to split text using <space>", |
| | | ) |
| | | group.add_argument( |
| | | "--input_size", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="The number of dimensions for input features.", |
| | | ) |
| | | group.add_argument( |
| | | "--init", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="Type of model initialization to use.", |
| | | ) |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(TransducerModel), |
| | | help="The keyword arguments for the model class.", |
| | | ) |
| | | # group.add_argument( |
| | | # "--encoder_conf", |
| | | # action=NestedDictAction, |
| | | # default={}, |
| | | # help="The keyword arguments for the encoder class.", |
| | | # ) |
| | | group.add_argument( |
| | | "--joint_network_conf", |
| | | action=NestedDictAction, |
| | | default={}, |
| | | help="The keyword arguments for the joint network class.", |
| | | ) |
| | | group = parser.add_argument_group(description="Preprocess related.") |
| | | group.add_argument( |
| | | "--use_preprocessor", |
| | | type=str2bool, |
| | | default=True, |
| | | help="Whether to apply preprocessing to input data.", |
| | | ) |
| | | group.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="bpe", |
| | | choices=["bpe", "char", "word", "phn"], |
| | | help="The type of tokens to use during tokenization.", |
| | | ) |
| | | group.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The path of the sentencepiece model.", |
| | | ) |
| | | parser.add_argument( |
| | | "--non_linguistic_symbols", |
| | | type=str_or_none, |
| | | help="The 'non_linguistic_symbols' file path.", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese"], |
| | | default=None, |
| | | help="Text cleaner to use.", |
| | | ) |
| | | parser.add_argument( |
| | | "--g2p", |
| | | type=str_or_none, |
| | | choices=g2p_choices, |
| | | default=None, |
| | | help="g2p method to use if --token_type=phn.", |
| | | ) |
| | | parser.add_argument( |
| | | "--speech_volume_normalize", |
| | | type=float_or_none, |
| | | default=None, |
| | | help="Normalization value for maximum amplitude scaling.", |
| | | ) |
| | | parser.add_argument( |
| | | "--rir_scp", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The RIR SCP file path.", |
| | | ) |
| | | parser.add_argument( |
| | | "--rir_apply_prob", |
| | | type=float, |
| | | default=1.0, |
| | | help="The probability of the applied RIR convolution.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_scp", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The path of noise SCP file.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_apply_prob", |
| | | type=float, |
| | | default=1.0, |
| | | help="The probability of the applied noise addition.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_db_range", |
| | | type=str, |
| | | default="13_15", |
| | | help="The range of the noise decibel level.", |
| | | ) |
| | | for class_choices in cls.class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | | # e.g. --decoder and --decoder_conf |
| | | class_choices.add_arguments(group) |
| | | |
| | | @classmethod |
| | | def build_collate_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Callable[ |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | """Build collate function. |
| | | Args: |
| | | cls: ASRTransducerTask object. |
| | | args: Task arguments. |
| | | train: Training mode. |
| | | Return: |
| | | : Callable collate function. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | | @classmethod |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | """Build pre-processing function. |
| | | Args: |
| | | cls: ASRTransducerTask object. |
| | | args: Task arguments. |
| | | train: Training mode. |
| | | Return: |
| | | : Callable pre-processing function. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | if args.use_preprocessor: |
| | | retval = CommonPreprocessor( |
| | | train=train, |
| | | token_type=args.token_type, |
| | | token_list=args.token_list, |
| | | bpemodel=args.bpemodel, |
| | | non_linguistic_symbols=args.non_linguistic_symbols, |
| | | text_cleaner=args.cleaner, |
| | | g2p_type=args.g2p, |
| | | split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | | rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, |
| | | rir_apply_prob=args.rir_apply_prob |
| | | if hasattr(args, "rir_apply_prob") |
| | | else 1.0, |
| | | noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, |
| | | noise_apply_prob=args.noise_apply_prob |
| | | if hasattr(args, "noise_apply_prob") |
| | | else 1.0, |
| | | noise_db_range=args.noise_db_range |
| | | if hasattr(args, "noise_db_range") |
| | | else "13_15", |
| | | speech_volume_normalize=args.speech_volume_normalize |
| | | if hasattr(args, "rir_scp") |
| | | else None, |
| | | ) |
| | | else: |
| | | retval = None |
| | | |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def required_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | """Required data depending on task mode. |
| | | Args: |
| | | cls: ASRTransducerTask object. |
| | | train: Training mode. |
| | | inference: Inference mode. |
| | | Return: |
| | | retval: Required task data. |
| | | """ |
| | | if not inference: |
| | | retval = ("speech", "text") |
| | | else: |
| | | retval = ("speech",) |
| | | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def optional_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | """Optional data depending on task mode. |
| | | Args: |
| | | cls: ASRTransducerTask object. |
| | | train: Training mode. |
| | | inference: Inference mode. |
| | | Return: |
| | | retval: Optional task data. |
| | | """ |
| | | retval = () |
| | | assert check_return_type(retval) |
| | | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> TransducerModel: |