| | |
| | | from funasr.models.e2e_sa_asr import SAASRModel |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | from funasr.models.e2e_asr_bat import BATModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | |
| | | from funasr.models.postencoder.hugging_face_transformers_postencoder import ( |
| | | HuggingFaceTransformersPostEncoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.preencoder.linear import LinearProjection |
| | | from funasr.models.preencoder.sinc import LightweightSincConvs |
| | |
| | | timestamp_prediction=TimestampPredictor, |
| | | rnnt=TransducerModel, |
| | | rnnt_unified=UnifiedTransducerModel, |
| | | bat=BATModel, |
| | | sa_asr=SAASRModel, |
| | | ), |
| | | type_check=FunASRModel, |
| | |
| | | ctc_predictor=None, |
| | | cif_predictor_v2=CifPredictorV2, |
| | | cif_predictor_v3=CifPredictorV3, |
| | | bat_predictor=BATPredictor, |
| | | ), |
| | | type_check=None, |
| | | default="cif_predictor", |
| | |
| | | |
| | | return model |
| | | |
class ASRBATTask(ASRTask):
    """ASR Boundary Aware Transducer (BAT) task definition.

    ``build_model`` assembles the frontend, SpecAugment, normalization,
    encoder, transducer decoder, optional attention decoder, joint network
    and predictor selected via ``args`` into a single model instance.
    """

    # Number of optimizers this task trains with.
    num_optimizers: int = 1

    # Component choice registries for this task (presumably consumed by
    # ASRTask to register the corresponding CLI options -- verify).
    # NOTE(review): ``decoder_choices`` is used in build_model but is not
    # listed here; build_model therefore reads ``args.decoder`` defensively.
    class_choices_list = [
        model_choices,
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
        joint_network_choices,
        predictor_choices,
    ]

    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> BATModel:
        """Build the BAT model described by ``args``.

        Args:
            cls: ASRBATTask class.
            args: Task arguments. ``args.token_list`` is either a path to a
                UTF-8 file with one token per line, or an already loaded
                list/tuple of tokens.

        Returns:
            model: ASR BAT model. The class is chosen by ``args.model``;
                when that is absent or None, the ``bat`` entry is used.

        Raises:
            NotImplementedError: If ``args.init`` is set (custom
                initialization is not supported for this task).
            RuntimeError: If ``args.token_list`` is neither a str nor a
                list/tuple.
        """
        assert check_argument_types()

        # Fail fast: custom initialization is rejected anyway, so do it
        # before spending time building every sub-module.
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported. "
                "Initialization part will be reworked in a short future."
            )

        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder (``Encoder`` is the fallback when no encoder type is set)
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()

        # 5. Transducer decoder. Note: ``output_size`` is read as an
        # attribute here, unlike ``encoder.output_size()`` above.
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size

        # Optional auxiliary attention decoder.
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 7. Build model. Default to the BAT model: this task passes
        # ``predictor=`` and is declared to return a BATModel, so falling
        # back to a different transducer variant would be inconsistent.
        model_name = getattr(args, "model", None)
        model_class = model_choices.get_class(
            model_name if model_name is not None else "bat"
        )

        model = model_class(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            att_decoder=att_decoder,
            joint_network=joint_network,
            predictor=predictor,
            **args.model_conf,
        )

        return model
| | | |
| | | class ASRTaskSAASR(ASRTask): |
| | | # If you need more than one optimizers, change this value |