From 4dc3a1b011e1e72eb737417b8e0e0bec7a7e3a6e Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期二, 21 三月 2023 15:12:21 +0800
Subject: [PATCH] resolve conflict

---
 funasr/bin/sond_inference.py |   32 ++++++++++++++++++++++++++------
 1 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
index ab6d26f..5a0a8e2 100755
--- a/funasr/bin/sond_inference.py
+++ b/funasr/bin/sond_inference.py
@@ -42,7 +42,7 @@
     Examples:
         >>> import soundfile
         >>> import numpy as np
-        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
+        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
         >>> profile = np.load("profiles.npy")
         >>> audio, rate = soundfile.read("speech.wav")
         >>> speech2diar(audio, profile)
@@ -54,7 +54,7 @@
             self,
             diar_train_config: Union[Path, str] = None,
             diar_model_file: Union[Path, str] = None,
-            device: str = "cpu",
+            device: Union[str, torch.device] = "cpu",
             batch_size: int = 1,
             dtype: str = "float32",
             streaming: bool = False,
@@ -114,9 +114,19 @@
             # little-endian order: lower bit first
             return (np.array(list(b)[::-1]) == '1').astype(dtype)
 
-        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+        # process oov
+        seq = np.array([int(x) for x in seq])
+        new_seq = []
+        for i, x in enumerate(seq):
+            if x < 2 ** vec_dim:
+                new_seq.append(x)
+            else:
+                idx_list = np.where(seq < 2 ** vec_dim)[0]
+                idx = np.abs(idx_list - i).argmin()
+                new_seq.append(seq[idx_list[idx]])
+        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
 
-    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
+    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
         logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
         # upsampling outputs to match inputs
         ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
@@ -127,8 +137,14 @@
         ).squeeze(1).long()
         logits_idx = logits_idx[0].tolist()
         pse_labels = [self.token_list[x] for x in logits_idx]
+        if output_format == "pse_labels":
+            return pse_labels, None
+
         multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
         multi_labels = self.smooth_multi_labels(multi_labels)
+        if output_format == "binary_labels":
+            return multi_labels, None
+
         spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
         spk_turns = self.calc_spk_turns(multi_labels, spk_list)
         results = OrderedDict()
@@ -149,6 +165,7 @@
             self,
             speech: Union[torch.Tensor, np.ndarray],
             profile: Union[torch.Tensor, np.ndarray],
+            output_format: str = "speaker_turn"
     ):
         """Inference
 
@@ -178,7 +195,7 @@
         batch = to_device(batch, device=self.device)
 
         logits = self.diar_model.prediction_forward(**batch)
-        results, pse_labels = self.post_processing(logits, profile.shape[1])
+        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
 
         return results, pse_labels
 
@@ -367,7 +384,7 @@
             pse_label_writer = open("{}/labels.txt".format(output_path), "w")
         logging.info("Start to diarize...")
         result_list = []
-        for keys, batch in loader:
+        for idx, (keys, batch) in enumerate(loader):
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
@@ -385,6 +402,9 @@
                 pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
                 pse_label_writer.flush()
 
+            if idx % 100 == 0:
+                logging.info("Processing {:5d}: {}".format(idx, key))
+
         if output_path is not None:
             output_writer.close()
             pse_label_writer.close()

--
Gitblit v1.9.1