| | |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder |
| | | from funasr.models.e2e_asr import ASRModel |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.get_default_kwargs import get_default_kwargs |
| | |
| | | model_choices = ClassChoices( |
| | | "model", |
| | | classes=dict( |
| | | asr=ESPnetASRModel, |
| | | asr=ASRModel, |
| | | uniasr=UniASR, |
| | | paraformer=Paraformer, |
| | | paraformer_online=ParaformerOnline, |
| | |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | type_check=FunASRModel, |
| | | default="asr", |
| | | ) |
| | | preencoder_choices = ClassChoices( |
| | |
| | | token_type=args.token_type, |
| | | token_list=args.token_list, |
| | | bpemodel=args.bpemodel, |
| | | non_linguistic_symbols=args.non_linguistic_symbols, |
| | | non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None, |
| | | text_cleaner=args.cleaner, |
| | | g2p_type=args.g2p, |
| | | split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | model_dict = dict() |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | model_dict = dict() |