嘉渊
2023-04-28 f97e0eb9eee3f14c410ce905b73d0c83033dc1c9
funasr/tasks/asr.py
@@ -72,7 +72,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -125,9 +125,9 @@
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_predictor=TimestampPredictor,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    type_check=FunASRModel,
    default="asr",
)
preencoder_choices = ClassChoices(
@@ -810,9 +810,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -826,7 +826,7 @@
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                            '.pb')) if ".bin" in model_name else os.path.join(
                    model_dir, "{}.pth".format(model_name))
                    model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
@@ -1057,9 +1057,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -1073,7 +1073,7 @@
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                            '.pb')) if ".bin" in model_name else os.path.join(
                    model_dir, "{}.pth".format(model_name))
                    model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
@@ -1278,8 +1278,6 @@
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
        # 1. frontend
        if args.input_size is None:
@@ -1316,6 +1314,7 @@
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )
@@ -1332,12 +1331,3 @@
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
class ASRTaskAligner(ASRTaskParaformer):
    """Variant of ``ASRTaskParaformer`` that defines its own required data names."""

    @classmethod
    def required_data_names(
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the names of the data entries this task requires.

        Args:
            train: Accepted for interface compatibility; not consulted here.
            inference: Accepted for interface compatibility; not consulted here.

        Returns:
            Always the tuple ``("speech", "text")``.
        """
        # The result is the same regardless of the train/inference flags.
        return "speech", "text"