| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.vad_utils import slice_padding_fbank |
| | | from funasr.utils.speaker_utils import (check_audio_list, |
| | | sv_preprocess, |
| | | sv_chunk, |
| | | CAMPPlus, |
| | | extract_feature, |
| | | from funasr.utils.speaker_utils import (check_audio_list, |
| | | sv_preprocess, |
| | | sv_chunk, |
| | | CAMPPlus, |
| | | extract_feature, |
| | | postprocess, |
| | | distribute_spk) |
| | | distribute_spk, ERes2Net) |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.utils.cluster_backend import ClusterBackend |
| | | from funasr.utils.modelscope_utils import get_cache_dir |
| | |
| | | ) |
| | | |
| | | sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin") |
| | | if not os.path.exists(sv_model_file): |
| | | sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt") |
| | | if not os.path.exists(sv_model_file): |
| | | raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file)) |
| | | |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | |
| | | ##### speaker_verification ##### |
| | | ################################## |
| | | # load sv model |
| | | sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) |
| | | sv_model = CAMPPlus() |
| | | sv_model_dict = torch.load(sv_model_file) |
| | | print(f'load sv model params: {sv_model_file}') |
| | | if os.path.basename(sv_model_file) == "campplus_cn_common.bin": |
| | | sv_model = CAMPPlus() |
| | | else: |
| | | sv_model = ERes2Net() |
| | | if ngpu > 0: |
| | | sv_model.cuda() |
| | | sv_model.load_state_dict(sv_model_dict) |
| | | sv_model.eval() |
| | | cb_model = ClusterBackend() |
| | |
| | | embs = [] |
| | | for x in wavs: |
| | | x = extract_feature([x]) |
| | | if ngpu > 0: |
| | | x = x.cuda() |
| | | embs.append(sv_model(x)) |
| | | embs = torch.cat(embs) |
| | | embeddings.append(embs.detach().numpy()) |
| | | embeddings.append(embs.cpu().detach().numpy()) |
| | | embeddings = np.concatenate(embeddings) |
| | | labels = cb_model(embeddings) |
| | | sv_output = postprocess(segments, vad_segments, labels, embeddings) |