| | |
| | | from funasr.utils.speaker_utils import (check_audio_list, |
| | | sv_preprocess, |
| | | sv_chunk, |
| | | CAMPPlus, |
| | | extract_feature, |
| | | postprocess, |
| | | distribute_spk) |
| | | import funasr.modules.cnn as sv_module |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.utils.cluster_backend import ClusterBackend |
| | | from funasr.utils.modelscope_utils import get_cache_dir |
| | |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin") |
| | | sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml") |
| | | if not os.path.exists(sv_model_config_path): |
| | | sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}} |
| | | else: |
| | | with open(sv_model_config_path, 'r') as f: |
| | | sv_model_config = yaml.load(f, Loader=yaml.FullLoader) |
| | | if sv_model_config['models_config'] is None: |
| | | sv_model_config['models_config'] = {} |
| | | sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file']) |
| | | |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | |
| | | ##### speaker_verification ##### |
| | | ################################## |
| | | # load sv model |
| | | if ngpu > 0: |
| | | sv_model_dict = torch.load(sv_model_file) |
| | | sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) |
| | | sv_model.cuda() |
| | | else: |
| | | sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) |
| | | sv_model = CAMPPlus() |
| | | sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) |
| | | sv_model.load_state_dict(sv_model_dict) |
| | | print(f'load sv model params: {sv_model_file}') |
| | | sv_model.eval() |
| | | cb_model = ClusterBackend() |
| | | vad_segments = [] |
| | |
| | | embs = [] |
| | | for x in wavs: |
| | | x = extract_feature([x]) |
| | | if ngpu > 0: |
| | | x = x.cuda() |
| | | embs.append(sv_model(x)) |
| | | embs = torch.cat(embs) |
| | | embeddings.append(embs.detach().numpy()) |
| | | embeddings.append(embs.cpu().detach().numpy()) |
| | | embeddings = np.concatenate(embeddings) |
| | | labels = cb_model(embeddings) |
| | | sv_output = postprocess(segments, vad_segments, labels, embeddings) |