| | |
| | | |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.models.base_model import FunASRModel |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == "ss": |
| | | model_dict = model_dict['model'] |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | if task_name == "vad": |