From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/bin/sv_inference.py |   62 +++++++++++++++++--------------
 1 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index b0fae38..7e63bbd 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import os
@@ -26,14 +29,14 @@
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-
+from funasr.utils.misc import statistic_model_parameters
 
 class Speech2Xvector:
     """Speech2Xvector class
 
     Examples:
         >>> import soundfile
-        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pth")
+        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
         >>> audio, rate = soundfile.read("speech.wav")
         >>> speech2xvector(audio)
         [(text, token, token_int, hypothesis object), ...]
@@ -59,6 +62,7 @@
             device=device
         )
         logging.info("sv_model: {}".format(sv_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
         logging.info("sv_train_args: {}".format(sv_train_args))
         sv_model.to(dtype=getattr(torch, dtype)).eval()
 
@@ -156,21 +160,22 @@
 
 
 def inference_modelscope(
-        output_dir: Optional[str],
-        batch_size: int,
-        dtype: str,
-        ngpu: int,
-        seed: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        key_file: Optional[str],
-        sv_train_config: Optional[str],
-        sv_model_file: Optional[str],
-        model_tag: Optional[str],
+        output_dir: Optional[str] = None,
+        batch_size: int = 1,
+        dtype: str = "float32",
+        ngpu: int = 1,
+        seed: int = 0,
+        num_workers: int = 0,
+        log_level: Union[int, str] = "INFO",
+        key_file: Optional[str] = None,
+        sv_train_config: Optional[str] = "sv.yaml",
+        sv_model_file: Optional[str] =  "sv.pb",
+        model_tag: Optional[str] = None,
         allow_variable_data_keys: bool = True,
         streaming: bool = False,
         embedding_node: str = "resnet1_dense",
         sv_threshold: float = 0.9465,
+        param_dict: Optional[dict] = None,
         **kwargs,
 ):
     assert check_argument_types()
@@ -183,6 +188,7 @@
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
+    logging.info("param_dict: {}".format(param_dict))
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -212,7 +218,9 @@
             data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
             raw_inputs: Union[np.ndarray, torch.Tensor] = None,
             output_dir_v2: Optional[str] = None,
+            param_dict: Optional[dict] = None,
     ):
+        logging.info("param_dict: {}".format(param_dict))
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
@@ -233,11 +241,10 @@
 
         # 7 .Start for-loop
         output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        embd_fd, ref_emb_fd, score_fd = None, None, None
+        embd_writer, ref_embd_writer, score_writer = None, None, None
         if output_path is not None:
             os.makedirs(output_path, exist_ok=True)
-            embd_writer = WriteHelper("ark:{}/xvector.ark".format(output_path))
-            # embd_fd = open(os.path.join(output_path, "xvector.ark"), "wb")
+            embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
         sv_result_list = []
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
@@ -249,6 +256,7 @@
             embedding, ref_embedding, score = speech2xvector(**batch)
             # Only supporting batch_size==1
             key = keys[0]
+            normalized_score = 0.0
             if score is not None:
                 score = score.item()
                 normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
@@ -257,23 +265,21 @@
                 item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
             sv_result_list.append(item)
             if output_path is not None:
-                # kaldiio.save_mat(embd_fd, embedding[0].cpu().numpy(), key)
                 embd_writer(key, embedding[0].cpu().numpy())
                 if ref_embedding is not None:
-                    if ref_emb_fd is None:
-                        # ref_emb_fd = open(os.path.join(output_path, "ref_xvector.ark"), "wb")
-                        ref_embd_writer = WriteHelper("ark:{}/ref_xvector.ark".format(output_path))
-                        score_fd = open(os.path.join(output_path, "score.txt"), "w")
-                    # kaldiio.save_mat(ref_emb_fd, ref_embedding[0].cpu().numpy(), key)
+                    if ref_embd_writer is None:
+                        ref_embd_writer = WriteHelper(
+                            "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
+                        )
+                        score_writer = open(os.path.join(output_path, "score.txt"), "w")
                     ref_embd_writer(key, ref_embedding[0].cpu().numpy())
-                    score_fd.write("{:.6f}\n".format(score.item()))
+                    score_writer.write("{} {:.6f}\n".format(key, normalized_score))
+
         if output_path is not None:
-            # embd_fd.close()
             embd_writer.close()
-            if ref_emb_fd is not None:
-                # ref_emb_fd.close()
-                ref_emb_fd.close()
-                score_fd.close()
+            if ref_embd_writer is not None:
+                ref_embd_writer.close()
+                score_writer.close()
 
         return sv_result_list
 

--
Gitblit v1.9.1