From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Wed, 31 May 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/bin/sv_infer.py | 163 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 163 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
new file mode 100755
index 0000000..1517bfa
--- /dev/null
+++ b/funasr/bin/sv_infer.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from kaldiio import WriteHelper
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.sv import SVTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.misc import statistic_model_parameters
+
+class Speech2Xvector:
+    """Extract a speaker x-vector embedding from speech for speaker verification.
+
+    Examples:
+        >>> import soundfile
+        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2xvector(audio)
+        (embedding, None, None)  # i.e. (embedding, ref_embedding, similarity_score)
+
+    """
+
+    def __init__(
+        self,
+        sv_train_config: Union[Path, str] = None,
+        sv_model_file: Union[Path, str] = None,
+        device: str = "cpu",
+        batch_size: int = 1,  # NOTE(review): accepted but never used in this class
+        dtype: str = "float32",
+        streaming: bool = False,  # NOTE(review): accepted but never used in this class
+        embedding_node: str = "resnet1_dense",
+    ):
+        assert check_argument_types()
+
+        # Build the SV model and its training args from the config/checkpoint pair.
+        sv_model, sv_train_args = SVTask.build_model_from_file(
+            config_file=sv_train_config,
+            model_file=sv_model_file,
+            device=device
+        )
+        logging.info("sv_model: {}".format(sv_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
+        logging.info("sv_train_args: {}".format(sv_train_args))
+        sv_model.to(dtype=getattr(torch, dtype)).eval()  # cast once, inference mode
+
+        self.sv_model = sv_model
+        self.sv_train_args = sv_train_args
+        self.device = device
+        self.dtype = dtype
+        self.embedding_node = embedding_node  # decoder output node to read the x-vector from
+
+    @torch.no_grad()
+    def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+        # Accept the waveform as either a numpy array or a torch tensor.
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        # data: (Nsamples,) -> (1, Nsamples); assumes a 1-D mono waveform -- TODO confirm
+        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        # lengths: (1,)
+        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        batch = {"speech": speech, "speech_lengths": lengths}
+
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        # b. Forward Encoder
+        enc, ilens = self.sv_model.encode(**batch)
+
+        # c. Forward Pooling
+        pooling = self.sv_model.pooling_layer(enc)
+
+        # d. Forward Decoder
+        outputs, embeddings = self.sv_model.decoder(pooling)  # embeddings: presumably a dict of node name -> tensor (membership/.keys() used below)
+
+        if self.embedding_node not in embeddings:
+            raise ValueError("Required embedding node {} not in {}".format(
+                self.embedding_node, embeddings.keys()))
+
+        return embeddings[self.embedding_node]
+
+    @torch.no_grad()
+    def __call__(
+        self, speech: Union[torch.Tensor, np.ndarray],
+        ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
+    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+            ref_speech: Reference speech to compare; when omitted, only the
+        Returns:
+            embedding, ref_embedding, similarity_score
+
+        """
+        assert check_argument_types()
+        self.sv_model.eval()
+        embedding = self.calculate_embedding(speech)
+        ref_emb, score = None, None  # both stay None when no reference speech is given
+        if ref_speech is not None:
+            ref_emb = self.calculate_embedding(ref_speech)
+            score = torch.cosine_similarity(embedding, ref_emb)  # along dim=1 (default)
+
+        results = (embedding, ref_emb, score)
+        assert check_return_type(results)
+        return results
+
+    @staticmethod
+    def from_pretrained(
+        model_tag: Optional[str] = None,
+        **kwargs: Optional[Any],
+    ):
+        """Build Speech2Xvector instance from the pretrained model.
+
+        Args:
+            model_tag (Optional[str]): Model tag of the pretrained models.
+                Currently, the tags of espnet_model_zoo are supported.
+
+        Returns:
+            Speech2Xvector: Speech2Xvector instance.
+
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))  # adds train_config/model_file paths
+
+        return Speech2Xvector(**kwargs)
+
+
+
+
--
Gitblit v1.9.1