| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from kaldiio import WriteHelper |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.tasks.sv import SVTask |
| | | from funasr.tasks.asr import ASRTask |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | |
| | | |
| | | class Speech2Xvector: |
| | | """Speech2Xvector class |
| | |
| | | streaming: bool = False, |
| | | embedding_node: str = "resnet1_dense", |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | # TODO: 1. Build SV model |
| | | sv_model, sv_train_args = SVTask.build_model_from_file( |
| | | sv_model, sv_train_args = build_model_from_file( |
| | | config_file=sv_train_config, |
| | | model_file=sv_model_file, |
| | | device=device |
| | | cmvn_file=None, |
| | | device=device, |
| | | task_name="sv", |
| | | mode="sv", |
| | | ) |
| | | logging.info("sv_model: {}".format(sv_model)) |
| | | logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model))) |
| | |
| | | embedding, ref_embedding, similarity_score |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | self.sv_model.eval() |
| | | embedding = self.calculate_embedding(speech) |
| | | ref_emb, score = None, None |
| | |
| | | score = torch.cosine_similarity(embedding, ref_emb) |
| | | |
| | | results = (embedding, ref_emb, score) |
| | | assert check_return_type(results) |
| | | return results |
| | | |
| | | @staticmethod |
| | | def from_pretrained( |
| | | model_tag: Optional[str] = None, |
| | | **kwargs: Optional[Any], |
| | | ): |
| | | """Build Speech2Xvector instance from the pretrained model. |
| | | |
| | | Args: |
| | | model_tag (Optional[str]): Model tag of the pretrained models. |
| | | Currently, the tags of espnet_model_zoo are supported. |
| | | |
| | | Returns: |
| | | Speech2Xvector: Speech2Xvector instance. |
| | | |
| | | """ |
| | | if model_tag is not None: |
| | | try: |
| | | from espnet_model_zoo.downloader import ModelDownloader |
| | | |
| | | except ImportError: |
| | | logging.error( |
| | | "`espnet_model_zoo` is not installed. " |
| | | "Please install via `pip install -U espnet_model_zoo`." |
| | | ) |
| | | raise |
| | | d = ModelDownloader() |
| | | kwargs.update(**d.download_and_unpack(model_tag)) |
| | | |
| | | return Speech2Xvector(**kwargs) |
| | | |
| | | |
| | | |
| | | |