嘉渊
2023-04-28 f97e0eb9eee3f14c410ce905b73d0c83033dc1c9
funasr/tasks/asr.py
@@ -72,7 +72,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -127,7 +127,7 @@
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    type_check=FunASRModel,
    default="asr",
)
preencoder_choices = ClassChoices(
@@ -810,9 +810,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -826,7 +826,7 @@
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                            '.pb')) if ".bin" in model_name else os.path.join(
                    model_dir, "{}.pth".format(model_name))
                    model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
@@ -1057,9 +1057,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -1073,7 +1073,7 @@
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
                                                                            '.pb')) if ".bin" in model_name else os.path.join(
                    model_dir, "{}.pth".format(model_name))
                    model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)