嘉渊
2023-04-27 6997763bf65705257fe6bca6ee63fcf006122abb
funasr/tasks/abs_task.py
@@ -30,7 +30,7 @@
import torch.nn
import torch.optim
import yaml
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -230,8 +230,8 @@
        >>> cls.check_task_requirements()
        If your model is defined as following,
        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
        >>> class Model(AbsESPnetModel):
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass
        then "required_data_names" should be as
@@ -251,8 +251,8 @@
        >>> cls.check_task_requirements()
        If your model is defined as follows,
        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
        >>> class Model(AbsESPnetModel):
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass
        then "optional_data_names" should be as
@@ -263,7 +263,7 @@
    @classmethod
    @abstractmethod
    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
        raise NotImplementedError
    @classmethod
@@ -1235,9 +1235,9 @@
        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
@@ -1921,7 +1921,7 @@
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
    ) -> Tuple[FunASRModel, argparse.Namespace]:
        """Build model from the files.
        This method is used for inference or fine-tuning.
@@ -1948,9 +1948,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        if model_file is not None: