From 7acfa5efd9d74c8727b5c16e739a34e8e07373f1 Mon Sep 17 00:00:00 2001
From: Zhihao Du <neo.dzh@alibaba-inc.com>
Date: Thu, 16 Mar 2023 19:41:56 +0800
Subject: [PATCH] Merge pull request #250 from alibaba-damo-academy/dev_dzh

---
 funasr/models/e2e_diar_sond.py      |   26 ++++++++----
 funasr/bin/sond_inference.py        |   30 ++++++++++++--
 funasr/tasks/diar.py                |   17 ++++++++
 funasr/datasets/iterable_dataset.py |    3 +
 4 files changed, 60 insertions(+), 16 deletions(-)

diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
index 936dc21..5a0a8e2 100755
--- a/funasr/bin/sond_inference.py
+++ b/funasr/bin/sond_inference.py
@@ -54,7 +54,7 @@
             self,
             diar_train_config: Union[Path, str] = None,
             diar_model_file: Union[Path, str] = None,
-            device: str = "cpu",
+            device: Union[str, torch.device] = "cpu",
             batch_size: int = 1,
             dtype: str = "float32",
             streaming: bool = False,
@@ -114,9 +114,19 @@
             # little-endian order: lower bit first
             return (np.array(list(b)[::-1]) == '1').astype(dtype)
 
-        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+        # process oov
+        seq = np.array([int(x) for x in seq])
+        new_seq = []
+        for i, x in enumerate(seq):
+            if x < 2 ** vec_dim:
+                new_seq.append(x)
+            else:
+                idx_list = np.where(seq < 2 ** vec_dim)[0]
+                idx = np.abs(idx_list - i).argmin()
+                new_seq.append(seq[idx_list[idx]])
+        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
 
-    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
+    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
         logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
         # upsampling outputs to match inputs
         ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
@@ -127,8 +137,14 @@
         ).squeeze(1).long()
         logits_idx = logits_idx[0].tolist()
         pse_labels = [self.token_list[x] for x in logits_idx]
+        if output_format == "pse_labels":
+            return pse_labels, None
+
         multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
         multi_labels = self.smooth_multi_labels(multi_labels)
+        if output_format == "binary_labels":
+            return multi_labels, None
+
         spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
         spk_turns = self.calc_spk_turns(multi_labels, spk_list)
         results = OrderedDict()
@@ -149,6 +165,7 @@
             self,
             speech: Union[torch.Tensor, np.ndarray],
             profile: Union[torch.Tensor, np.ndarray],
+            output_format: str = "speaker_turn"
     ):
         """Inference
 
@@ -178,7 +195,7 @@
         batch = to_device(batch, device=self.device)
 
         logits = self.diar_model.prediction_forward(**batch)
-        results, pse_labels = self.post_processing(logits, profile.shape[1])
+        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
 
         return results, pse_labels
 
@@ -367,7 +384,7 @@
             pse_label_writer = open("{}/labels.txt".format(output_path), "w")
         logging.info("Start to diarize...")
         result_list = []
-        for keys, batch in loader:
+        for idx, (keys, batch) in enumerate(loader):
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
@@ -385,6 +402,9 @@
                 pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
                 pse_label_writer.flush()
 
+            if idx % 100 == 0:
+                logging.info("Processing {:5d}: {}".format(idx, key))
+
         if output_path is not None:
             output_writer.close()
             pse_label_writer.close()
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 49c7068..c8c51d4 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -8,6 +8,7 @@
 from typing import Iterator
 from typing import Tuple
 from typing import Union
+from typing import List
 
 import kaldiio
 import numpy as np
@@ -129,7 +130,7 @@
         non_iterable_list = []
         self.path_name_type_list = []
 
-        if not isinstance(path_name_type_list[0], Tuple):
+        if not isinstance(path_name_type_list[0], (Tuple, List)):
             path = path_name_type_list[0]
             name = path_name_type_list[1]
             _type = path_name_type_list[2]
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 258d780..de669f2 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -59,7 +59,8 @@
         normalize_speech_speaker: bool = False,
         ignore_id: int = -1,
         speaker_discrimination_loss_weight: float = 1.0,
-        inter_score_loss_weight: float = 0.0
+        inter_score_loss_weight: float = 0.0,
+        inputs_type: str = "raw",
     ):
         assert check_argument_types()
 
@@ -86,14 +87,12 @@
         )
         self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
         self.pse_embedding = self.generate_pse_embedding()
-        # self.register_buffer("pse_embedding", pse_embedding)
         self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
-        # self.register_buffer("power_weight", power_weight)
         self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
-        # self.register_buffer("int_token_arr", int_token_arr)
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
         self.forward_steps = 0
+        self.inputs_type = inputs_type
 
     def generate_pse_embedding(self):
         embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -125,9 +124,14 @@
             binary_labels: (Batch, frames, max_spk_num)
             binary_labels_lengths: (Batch,)
         """
-        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
+        assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
         self.forward_steps = self.forward_steps + 1
+        if self.pse_embedding.device != speech.device:
+            self.pse_embedding = self.pse_embedding.to(speech.device)
+            self.power_weight = self.power_weight.to(speech.device)
+            self.int_token_arr = self.int_token_arr.to(speech.device)
+
         # 1. Network forward
         pred, inter_outputs = self.prediction_forward(
             speech, speech_lengths,
@@ -149,9 +153,13 @@
         # the sequence length of 'pred' might be slightly less than the
         # length of 'spk_labels'. Here we force them to be equal.
         length_diff_tolerance = 2
-        length_diff = pse_labels.shape[1] - pred.shape[1]
-        if 0 < length_diff <= length_diff_tolerance:
-            pse_labels = pse_labels[:, 0: pred.shape[1]]
+        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
+        if length_diff <= length_diff_tolerance:
+            min_len = min(pred.shape[1], pse_labels.shape[1])
+            pse_labels = pse_labels[:, :min_len]
+            pred = pred[:, :min_len]
+            cd_score = cd_score[:, :min_len]
+            ci_score = ci_score[:, :min_len]
 
         loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
         loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
@@ -299,7 +307,7 @@
             speech: torch.Tensor,
             speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.encoder is not None:
+        if self.encoder is not None and self.inputs_type == "raw":
             speech, speech_lengths = self.encode(speech, speech_lengths)
             speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
             speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 9875f6a..096a5c8 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -507,7 +507,7 @@
             config_file: Union[Path, str] = None,
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
+            device: Union[str, torch.device] = "cpu",
     ):
         """Build model from the files.
 
@@ -562,6 +562,7 @@
                 model.load_state_dict(model_dict)
             else:
                 model_dict = torch.load(model_file, map_location=device)
+        model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
         model.load_state_dict(model_dict)
         if model_name_pth is not None and not os.path.exists(model_name_pth):
             torch.save(model_dict, model_name_pth)
@@ -570,6 +571,20 @@
         return model, args
 
     @classmethod
+    def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):  # returns the entries of src_dict whose keys also appear in dest_dict
+        from collections import OrderedDict
+        new_dict = OrderedDict()  # filled in src_dict iteration order
+        for key, value in src_dict.items():
+            if key in dest_dict:
+                new_dict[key] = value
+            else:
+                logging.info("{} is no longer needed in this model.".format(key))  # key absent from dest_dict: left out of the result
+        for key, value in dest_dict.items():  # second pass only reports; nothing is added to new_dict here
+            if key not in new_dict:
+                logging.warning("{} is missed in checkpoint.".format(key))
+        return new_dict
+
+    @classmethod
     def convert_tf2torch(
             cls,
             model,

--
Gitblit v1.9.1