| | |
| | | |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.models.base_model import FunASRModel |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == "ss": |
| | | model_dict = model_dict['model'] |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if task_name == "vad": |
| | | model.encoder.load_state_dict(model_dict) |
| | | else: |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | | logging.info("model_file is saved to pth: {}".format(model_name_pth)) |