jmwang66
2023-06-29 98abc0e5ac1a1da0fe1802d9ffb623802fbf0b2f
funasr/bin/sv_infer.py
@@ -12,8 +12,6 @@
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
@@ -42,7 +40,6 @@
            streaming: bool = False,
            embedding_node: str = "resnet1_dense",
    ):
        assert check_argument_types()
        # TODO: 1. Build SV model
        sv_model, sv_train_args = build_model_from_file(
@@ -108,7 +105,6 @@
            embedding, ref_embedding, similarity_score
        """
        assert check_argument_types()
        self.sv_model.eval()
        embedding = self.calculate_embedding(speech)
        ref_emb, score = None, None
@@ -117,35 +113,4 @@
            score = torch.cosine_similarity(embedding, ref_emb)
        results = (embedding, ref_emb, score)
        assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ):
        """Build Speech2Xvector instance from the pretrained model.
        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
        Returns:
            Speech2Xvector: Speech2Xvector instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2Xvector(**kwargs)