| | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | from pathlib import Path |
| | | from typing import Callable |
| | | from typing import Collection |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.decoder.rnn_decoder import RNNDecoder |
| | | from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | DynamicConvolution2DTransformerDecoder, # noqa: H301 |
| | | ) |
| | |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | LightweightConvolutionTransformerDecoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.postencoder.hugging_face_transformers_postencoder import ( |
| | | HuggingFaceTransformersPostEncoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.preencoder.linear import LinearProjection |
| | | from funasr.models.preencoder.sinc import LightweightSincConvs |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.subsampling import Conv1dSubsampling |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2 |
| | | from funasr.modules.subsampling import Conv1dSubsampling |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | |
| | | frontend_choices = ClassChoices( |
| | | name="frontend", |
| | | classes=dict( |
| | |
| | | sliding_window=SlidingWindow, |
| | | s3prl=S3prlFrontend, |
| | | fused=FusedFrontends, |
| | | wav_frontend=WavFrontend, |
| | | ), |
| | | type_check=AbsFrontend, |
| | | default="default", |
| | |
| | | uniasr=UniASR, |
| | | paraformer=Paraformer, |
| | | paraformer_bert=ParaformerBert, |
| | | bicif_paraformer=BiCifParaformer, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | default="asr", |
| | |
| | | rnn=RNNEncoder, |
| | | sanm=SANMEncoder, |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ), |
| | | type_check=AbsEncoder, |
| | | default="rnn", |
| | |
| | | cif_predictor=CifPredictor, |
| | | ctc_predictor=None, |
| | | cif_predictor_v2=CifPredictorV2, |
| | | cif_predictor_v3=CifPredictorV3, |
| | | ), |
| | | type_check=None, |
| | | default="cif_predictor", |
| | |
| | | |
| | | # NOTE(kamo): add_arguments(..., required=True) can't be used |
| | | # to provide --print_config mode. Instead of it, do as |
| | | required = parser.get_default("required") |
| | | required += ["token_list"] |
| | | # required = parser.get_default("required") |
| | | # required += ["token_list"] |
| | | |
| | | group.add_argument( |
| | | "--token_list", |
| | |
| | | type=str2bool, |
| | | default=True, |
| | | help="whether to split text using <space>", |
| | | ) |
| | | group.add_argument( |
| | | "--seg_dict_file", |
| | | type=str, |
| | | default=None, |
| | | help="seg_dict_file for text processing", |
| | | ) |
| | | group.add_argument( |
| | | "--init", |
| | |
| | | help="THe probability for applying RIR convolution.", |
| | | ) |
| | | parser.add_argument( |
| | | "--cmvn_file", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The file path of noise scp file.", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_scp", |
| | | type=str_or_none, |
| | | default=None, |
| | |
| | | text_cleaner=args.cleaner, |
| | | g2p_type=args.g2p, |
| | | split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | # NOTE(kamo): Check attribute existence for backward compatibility |
| | | rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, |
| | | rir_apply_prob=args.rir_apply_prob |
| | |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
@classmethod
def build_model_from_file(
    cls,
    config_file: Optional[Union[Path, str]] = None,
    model_file: Optional[Union[Path, str]] = None,
    cmvn_file: Optional[Union[Path, str]] = None,
    device: str = "cpu",
):
    """Build a model from a training config and, optionally, its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` in the directory of ``model_file`` is used.
        model_file: The model file saved when training. File names
            containing ``"model.ckpt-"`` or ``".bin"`` are treated as
            TensorFlow checkpoints: a previously converted copy (``.pb`` for
            ``.bin`` files, ``<name>.pth`` otherwise) is loaded if it exists,
            else the checkpoint is converted with ``cls.convert_tf2torch``
            and the result is cached at that path. Any other file is read
            with ``torch.load``. When ``None``, no weights are loaded and the
            model is returned as produced by ``cls.build_model``.
        cmvn_file: Optional CMVN file; overrides the config's ``cmvn_file``.
        device: Device type, "cpu", "cuda", or "cuda:N".

    Returns:
        Tuple ``(model, args)``; ``args`` is the loaded config wrapped in an
        ``argparse.Namespace``.

    Raises:
        AssertionError: If neither ``config_file`` nor ``model_file`` is given.
        RuntimeError: If the built model is not an ``AbsESPnetModel``.
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            # Resolve the bare "cuda" alias to a concrete index for map_location.
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        model_name_pth = None
        if "model.ckpt-" in model_name or ".bin" in model_name:
            # TensorFlow checkpoint: reuse a cached torch conversion if one
            # exists, otherwise convert now (and cache it below).
            if ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
            else:
                model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
            if os.path.exists(model_name_pth):
                logging.info("model_file is load from pth: {}".format(model_name_pth))
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = cls.convert_tf2torch(model, model_file)
        else:
            model_dict = torch.load(model_file, map_location=device)
        # Load exactly once, and only when there are weights to load: a strict
        # load of an empty dict would fail on every missing key.
        model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))

    return model, args
| | | |
@classmethod
def convert_tf2torch(
    cls,
    model,
    ckpt,
):
    """Convert a TensorFlow checkpoint into a torch state-dict update.

    Every sub-module named below converts its own variables through its
    ``convert_tf2torch(var_dict_tf, var_dict_torch)`` method. The partial
    results are merged with ``dict.update`` in the listed order, so a later
    sub-module overrides an earlier one on a key clash.

    Args:
        model: Model exposing ``encoder``, ``predictor``, ``decoder``,
            ``encoder2``, ``predictor2``, ``decoder2`` and ``stride_conv``.
        ckpt: Checkpoint location handed to ``load_tf_dict``.

    Returns:
        The merged dict of converted entries (presumably keyed like
        ``model.state_dict()`` — confirm against the sub-module converters).
    """
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    tf_vars = load_tf_dict(ckpt)
    torch_vars = model.state_dict()
    converted = {}
    # Conversion order matters only for overlapping keys; keep it fixed.
    for part_name in (
        "encoder",
        "predictor",
        "decoder",
        "encoder2",
        "predictor2",
        "decoder2",
        "stride_conv",
    ):
        part = getattr(model, part_name)
        converted.update(part.convert_tf2torch(tf_vars, torch_vars))
    return converted
| | | |
| | | |
| | | class ASRTaskParaformer(ASRTask): |
| | | # If you need more than one optimizers, change this value |
| | |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
@classmethod
def build_model_from_file(
    cls,
    config_file: Optional[Union[Path, str]] = None,
    model_file: Optional[Union[Path, str]] = None,
    cmvn_file: Optional[Union[Path, str]] = None,
    device: str = "cpu",
):
    """Build a model from a training config and, optionally, its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The yaml file saved when training. When omitted,
            ``config.yaml`` in the directory of ``model_file`` is used.
        model_file: The model file saved when training. File names
            containing ``"model.ckpt-"`` or ``".bin"`` are treated as
            TensorFlow checkpoints: a previously converted copy (``.pb`` for
            ``.bin`` files, ``<name>.pth`` otherwise) is loaded if it exists,
            else the checkpoint is converted with ``cls.convert_tf2torch``
            and the result is cached at that path. Any other file is read
            with ``torch.load``. When ``None``, no weights are loaded and the
            model is returned as produced by ``cls.build_model``.
        cmvn_file: Optional CMVN file; overrides the config's ``cmvn_file``.
        device: Device type, "cpu", "cuda", or "cuda:N".

    Returns:
        Tuple ``(model, args)``; ``args`` is the loaded config wrapped in an
        ``argparse.Namespace``.

    Raises:
        AssertionError: If neither ``config_file`` nor ``model_file`` is given.
        RuntimeError: If the built model is not an ``AbsESPnetModel``.
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
            "if the argument 'config_file' is not specified."
        )
        config_file = Path(model_file).parent / "config.yaml"
    else:
        config_file = Path(config_file)

    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
    args = argparse.Namespace(**args)
    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)

    if model_file is not None:
        logging.info("model_file is {}".format(model_file))
        if device == "cuda":
            # Resolve the bare "cuda" alias to a concrete index for map_location.
            device = f"cuda:{torch.cuda.current_device()}"
        model_dir = os.path.dirname(model_file)
        model_name = os.path.basename(model_file)
        model_name_pth = None
        if "model.ckpt-" in model_name or ".bin" in model_name:
            # TensorFlow checkpoint: reuse a cached torch conversion if one
            # exists, otherwise convert now (and cache it below).
            if ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
            else:
                model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
            if os.path.exists(model_name_pth):
                logging.info("model_file is load from pth: {}".format(model_name_pth))
                model_dict = torch.load(model_name_pth, map_location=device)
            else:
                model_dict = cls.convert_tf2torch(model, model_file)
        else:
            model_dict = torch.load(model_file, map_location=device)
        # Load exactly once, and only when there are weights to load: a strict
        # load of an empty dict would fail on every missing key.
        model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))

    # Move again so the model ends on the (possibly resolved) device string.
    model.to(device)
    return model, args
| | | |
@classmethod
def convert_tf2torch(
    cls,
    model,
    ckpt,
):
    """Convert a TensorFlow checkpoint into a torch state-dict update.

    The ``encoder``, ``predictor`` and ``decoder`` sub-modules each convert
    their own variables via ``convert_tf2torch(var_dict_tf, var_dict_torch)``.
    The partial results are merged with ``dict.update`` in that order, so a
    later sub-module overrides an earlier one on a key clash.

    Args:
        model: Model exposing ``encoder``, ``predictor`` and ``decoder``.
        ckpt: Checkpoint location handed to ``load_tf_dict``.

    Returns:
        The merged dict of converted entries (presumably keyed like
        ``model.state_dict()`` — confirm against the sub-module converters).
    """
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

    tf_vars = load_tf_dict(ckpt)
    torch_vars = model.state_dict()
    converted = {}
    # Keep the conversion order fixed; it decides which entry wins on overlap.
    for part_name in ("encoder", "predictor", "decoder"):
        part = getattr(model, part_name)
        converted.update(part.convert_tf2torch(tf_vars, torch_vars))
    return converted