| | |
| | | |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | |
def load_checkpoint(checkpoint_path, use_cuda=True):
    """Deserialize a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path: Path (or file-like object) accepted by ``torch.load``.
        use_cuda: When truthy and CUDA is available, the file is loaded with
            ``torch.load``'s default device handling. Otherwise every tensor
            is mapped onto the CPU.

    Returns:
        The object stored in the checkpoint file.
    """
    import logging

    # NOTE: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    if use_cuda and not torch.cuda.is_available():
        # Without this fallback a GPU-saved checkpoint cannot be deserialized
        # on a CPU-only host, because no map_location would be given.
        logging.warning(
            "CUDA is not available, loading checkpoint on CPU: {}".format(
                checkpoint_path
            )
        )
        use_cuda = False

    if use_cuda:
        return torch.load(checkpoint_path)
    return torch.load(checkpoint_path, map_location="cpu")
| | | |
def reload_ss_for_eval(model, checkpoint_path, use_cuda=False):
    """Restore ``model`` weights from the ``'model'`` entry of a checkpoint.

    The checkpoint is read through ``load_checkpoint`` and its ``'model'``
    entry is passed to ``model.load_state_dict`` with ``strict=False``.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        checkpoint_path: Location of the checkpoint file.
        use_cuda: Forwarded unchanged to ``load_checkpoint``.
    """
    state = load_checkpoint(checkpoint_path, use_cuda)['model']
    model.load_state_dict(state, strict=False)
| | | |
| | | def build_model_from_file( |
| | | config_file: Union[Path, str] = None, |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == 'ss': |
| | | reload_ss_for_eval(model, model_file, use_cuda=True) |
| | | logging.info("model is loaded from path: {}".format(model_file)) |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if task_name == "vad": |
| | | model.encoder.load_state_dict(model_dict) |
| | | else: |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | | logging.info("model_file is saved to pth: {}".format(model_name_pth)) |
| | |
| | | ckpt, |
| | | mode, |
| | | ): |
| | | assert mode == "paraformer" or mode == "uniasr" or mode == "sond" |
| | | assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp" |
| | | logging.info("start convert tf model to torch model") |
| | | from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict |
| | | var_dict_tf = load_tf_dict(ckpt) |
| | |
| | | # bias_encoder |
| | | var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | else: |
| | | elif "mode" == "sond": |
| | | if model.encoder is not None: |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | elif "mode" == "sv": |
| | | # speech encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # pooling layer |
| | | var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | else: |
| | | # encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # predictor |
| | | var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # bias_encoder |
| | | var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | return var_dict_torch_update |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
| | | def fileter_model_dict(src_dict: dict, dest_dict: dict): |
| | | from collections import OrderedDict |
| | |
| | | for key, value in dest_dict.items(): |
| | | if key not in new_dict: |
| | | logging.warning("{} is missed in checkpoint.".format(key)) |
| | | return new_dict |
| | | return new_dict |