| | |
| | | import torch.nn |
| | | import torch.optim |
| | | import yaml |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | |
| | | >>> cls.check_task_requirements() |
| | | If your model is defined as follows, |
| | | |
| | | >>> from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | >>> class Model(AbsESPnetModel): |
| | | >>> from funasr.models.base_model import FunASRModel |
| | | >>> class Model(FunASRModel): |
| | | ... def forward(self, input, output, opt=None): pass |
| | | |
| | | then "required_data_names" should be as |
| | |
| | | >>> cls.check_task_requirements() |
| | | If your model is defined as follows, |
| | | |
| | | >>> from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | >>> class Model(AbsESPnetModel): |
| | | >>> from funasr.models.base_model import FunASRModel |
| | | >>> class Model(FunASRModel): |
| | | ... def forward(self, input, output, opt=None): pass |
| | | |
| | | then "optional_data_names" should be as |
| | |
| | | |
| | | @classmethod |
| | | @abstractmethod |
| | | def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: |
| | | def build_model(cls, args: argparse.Namespace) -> FunASRModel: |
| | | raise NotImplementedError |
| | | |
| | | @classmethod |
| | |
| | | |
| | | # 2. Build model |
| | | model = cls.build_model(args=args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model = model.to( |
| | | dtype=getattr(torch, args.train_dtype), |
| | |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | ) -> Tuple[AbsESPnetModel, argparse.Namespace]: |
| | | ) -> Tuple[FunASRModel, argparse.Namespace]: |
| | | """Build model from the files. |
| | | |
| | | This method is used for inference or fine-tuning. |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | if model_file is not None: |