From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 26 四月 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)

---
 funasr/datasets/large_datasets/utils/padding.py |   52 +++++++++++++++++++++++-----------------------------
 1 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
index 20ba7a3..cb43a27 100644
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -7,7 +7,7 @@
     assert isinstance(data, list)
     assert "key" in data[0]
     assert "speech" in data[0] or "text" in data[0]
-    
+
     keys = [x["key"] for x in data]
 
     batch = {}
@@ -16,7 +16,7 @@
         if data_name == "key" or data_name == "sampling_rate":
             continue
         else:
-            if data_name != 'hotword_indxs':
+            if data_name != "hotword_indxs":
                 if data[0][data_name].dtype.kind == "i":
                     pad_value = int_pad_value
                     tensor_type = torch.int64
@@ -26,53 +26,47 @@
 
             tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
             tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
-            tensor_pad = pad_sequence(tensor_list,
-                                      batch_first=True,
-                                      padding_value=pad_value)
+            tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)
             batch[data_name] = tensor_pad
             batch[data_name + "_lengths"] = tensor_lengths
 
-    # DHA, EAHC NOT INCLUDED
+    # SAC LABEL INCLUDE
     if "hotword_indxs" in batch:
         # if hotword indxs in batch
         # use it to slice hotwords out
         hotword_list = []
         hotword_lengths = []
-        text = batch['text']
-        text_lengths = batch['text_lengths']
-        hotword_indxs = batch['hotword_indxs']
-        num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
-        B, t1 = text.shape
+        text = batch["text"]
+        text_lengths = batch["text_lengths"]
+        hotword_indxs = batch["hotword_indxs"]
+        dha_pad = torch.ones_like(text) * -1
+        _, t1 = text.shape
         t1 += 1  # TODO: as parameter which is same as predictor_bias
-        ideal_attn = torch.zeros(B, t1, num_hw+1)
         nth_hw = 0
-        for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
-            ideal_attn[b][:,-1] = 1
+        for b, (hotword_indx, one_text, length) in enumerate(
+            zip(hotword_indxs, text, text_lengths)
+        ):
+            dha_pad[b][:length] = 8405
             if hotword_indx[0] != -1:
                 start, end = int(hotword_indx[0]), int(hotword_indx[1])
-                hotword = one_text[start: end+1]
+                hotword = one_text[start : end + 1]
                 hotword_list.append(hotword)
-                hotword_lengths.append(end-start+1)
-                ideal_attn[b][start:end+1, nth_hw] = 1
-                ideal_attn[b][start:end+1, -1] = 0
+                hotword_lengths.append(end - start + 1)
+                dha_pad[b][start : end + 1] = one_text[start : end + 1]
                 nth_hw += 1
                 if len(hotword_indx) == 4 and hotword_indx[2] != -1:
                     # the second hotword if exist
                     start, end = int(hotword_indx[2]), int(hotword_indx[3])
-                    hotword_list.append(one_text[start: end+1])
-                    hotword_lengths.append(end-start+1)
-                    ideal_attn[b][start:end+1, nth_hw-1] = 1
-                    ideal_attn[b][start:end+1, -1] = 0
+                    hotword_list.append(one_text[start : end + 1])
+                    hotword_lengths.append(end - start + 1)
+                    dha_pad[b][start : end + 1] = one_text[start : end + 1]
                     nth_hw += 1
         hotword_list.append(torch.tensor([1]))
         hotword_lengths.append(1)
-        hotword_pad = pad_sequence(hotword_list,
-                                batch_first=True,
-                                padding_value=0)
+        hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
         batch["hotword_pad"] = hotword_pad
         batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
-        batch['ideal_attn'] = ideal_attn
-        del batch['hotword_indxs']
-        del batch['hotword_indxs_lengths']
-
+        batch["dha_pad"] = dha_pad
+        del batch["hotword_indxs"]
+        del batch["hotword_indxs_lengths"]
     return keys, batch

--
Gitblit v1.9.1