From 777ae05adb7f6934892d7685ce8e1e1ba57cc8b5 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Thu, 09 Mar 2023 12:16:58 +0800
Subject: [PATCH] add en sv model
---
funasr/models/encoder/resnet34_encoder.py | 1
egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py | 8 +-
funasr/tasks/sv.py | 110 +++++++++++++++++++++++++++++++++++-
egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py | 39 +++++++++++++
egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py | 21 +++++++
5 files changed, 171 insertions(+), 8 deletions(-)
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py
new file mode 100644
index 0000000..d3975ae
--- /dev/null
+++ b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py
@@ -0,0 +1,39 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+
+if __name__ == '__main__':
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
+ )
+
+ # extract speaker embedding
+ # for url use "spk_embedding" as key
+ rec_result = inference_sv_pipline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
+ enroll = rec_result["spk_embedding"]
+
+ # for local file use "spk_embedding" as key
+ rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')
+ same = rec_result["spk_embedding"]
+
+ import soundfile
+ wav = soundfile.read('sv_example_enroll.wav')[0]
+ # for raw inputs use "spk_embedding" as key
+ spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
+
+ rec_result = inference_sv_pipline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
+ different = rec_result["spk_embedding"]
+
+ # calculate cosine similarity for same speaker
+ sv_threshold = 0.9465
+ same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
+ same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+ print("Similarity:", same_cos)
+
+ # calculate cosine similarity for different speaker
+ diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
+ diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+ print("Similarity:", diff_cos)
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py
new file mode 100644
index 0000000..1151ceb
--- /dev/null
+++ b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer_sv.py
@@ -0,0 +1,21 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch'
+ )
+
+ # the same speaker
+ rec_result = inference_sv_pipline(audio_in=(
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
+ print("Similarity", rec_result["scores"])
+
+ # different speakers
+ rec_result = inference_sv_pipline(audio_in=(
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
+
+ print("Similarity", rec_result["scores"])
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py
index a48088c..87f3801 100644
--- a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py
+++ b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py
@@ -12,20 +12,20 @@
# for url use "utt_id" as key
rec_result = inference_sv_pipline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
- enroll = rec_result["utt_id"]
+ enroll = rec_result["spk_embedding"]
# for local file use "utt_id" as key
rec_result = inference_sv_pipline(audio_in='sv_example_same.wav')["test1"]
- same = rec_result["test1"]
+ same = rec_result["spk_embedding"]
import soundfile
wav = soundfile.read('sv_example_enroll.wav')[0]
# for raw inputs use "utt_id" as key
- spk_embedding = inference_sv_pipline(audio_in=wav)["utt_id"]
+ spk_embedding = inference_sv_pipline(audio_in=wav)["spk_embedding"]
rec_result = inference_sv_pipline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
- different = rec_result["utt_id"]
+ different = rec_result["spk_embedding"]
 # 对相同的说话人计算余弦相似度
sv_threshold = 0.9465
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 952ce15..930f7e0 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -387,7 +387,6 @@
return var_dict_torch_update
-
class ResNet34Diar(ResNet34):
def __init__(
self,
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index 16384a7..1b08c4d 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -1,14 +1,18 @@
import argparse
import logging
+import os
+from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
+from typing import Union
import numpy as np
import torch
+import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -21,7 +25,7 @@
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_sv import ESPnetSVModel
@@ -103,6 +107,7 @@
"encoder",
classes=dict(
resnet34=ResNet34,
+ resnet34_sp_l2reg=ResNet34_SP_L2Reg,
rnn=RNNEncoder,
),
type_check=AbsEncoder,
@@ -394,9 +399,16 @@
# 7. Pooling layer
pooling_class = pooling_choices.get_class(args.pooling_type)
+ pooling_dim = (2, 3)
+ eps = 1e-12
+ if hasattr(args, "pooling_type_conf"):
+ if "pooling_dim" in args.pooling_type_conf:
+ pooling_dim = args.pooling_type_conf["pooling_dim"]
+ if "eps" in args.pooling_type_conf:
+ eps = args.pooling_type_conf["eps"]
pooling_layer = pooling_class(
- pooling_dim=(2, 3),
- eps=1e-12,
+ pooling_dim=pooling_dim,
+ eps=eps,
)
if args.pooling_type == "statistic":
encoder_output_size *= 2
@@ -435,3 +447,95 @@
assert check_return_type(model)
return model
+
+ # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+ @classmethod
+ def build_model_from_file(
+ cls,
+ config_file: Union[Path, str] = None,
+ model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ ):
+ """Build model from the files.
+
+ This method is used for inference or fine-tuning.
+
+ Args:
+ config_file: The yaml file saved when training.
+ model_file: The model file saved when training.
+ cmvn_file: The cmvn file for front-end
+ device: Device type, "cpu", "cuda", or "cuda:N".
+
+ """
+ assert check_argument_types()
+ if config_file is None:
+ assert model_file is not None, (
+ "The argument 'model_file' must be provided "
+ "if the argument 'config_file' is not specified."
+ )
+ config_file = Path(model_file).parent / "config.yaml"
+ else:
+ config_file = Path(config_file)
+
+ with config_file.open("r", encoding="utf-8") as f:
+ args = yaml.safe_load(f)
+ if cmvn_file is not None:
+ args["cmvn_file"] = cmvn_file
+ args = argparse.Namespace(**args)
+ model = cls.build_model(args)
+ if not isinstance(model, AbsESPnetModel):
+ raise RuntimeError(
+ f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ )
+ model.to(device)
+ model_dict = dict()
+ model_name_pth = None
+ if model_file is not None:
+ logging.info("model_file is {}".format(model_file))
+ if device == "cuda":
+ device = f"cuda:{torch.cuda.current_device()}"
+ model_dir = os.path.dirname(model_file)
+ model_name = os.path.basename(model_file)
+ if "model.ckpt-" in model_name or ".bin" in model_name:
+ if ".bin" in model_name:
+ model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
+ else:
+ model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+ if os.path.exists(model_name_pth):
+ logging.info("model_file is load from pth: {}".format(model_name_pth))
+ model_dict = torch.load(model_name_pth, map_location=device)
+ else:
+ model_dict = cls.convert_tf2torch(model, model_file)
+ model.load_state_dict(model_dict)
+ else:
+ model_dict = torch.load(model_file, map_location=device)
+ model.load_state_dict(model_dict)
+ if model_name_pth is not None and not os.path.exists(model_name_pth):
+ torch.save(model_dict, model_name_pth)
+ logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+ return model, args
+
+ @classmethod
+ def convert_tf2torch(
+ cls,
+ model,
+ ckpt,
+ ):
+ logging.info("start convert tf model to torch model")
+ from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+ var_dict_tf = load_tf_dict(ckpt)
+ var_dict_torch = model.state_dict()
+ var_dict_torch_update = dict()
+ # speech encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # pooling layer
+ var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+
+ return var_dict_torch_update
--
Gitblit v1.9.1