From 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: Thu, 11 May 2023 16:26:24 +0800
Subject: [PATCH] Merge branch 'main' into dev_smohan

---
 funasr/datasets/large_datasets/utils/padding.py |   69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 62 insertions(+), 7 deletions(-)

diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
index e0feac6..20ba7a3 100644
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -13,15 +13,23 @@
     batch = {}
     data_names = data[0].keys()
     for data_name in data_names:
-        if data_name == "key" or data_name =="sampling_rate":
+        if data_name == "key" or data_name == "sampling_rate":
             continue
         else:
-            if data[0][data_name].dtype.kind == "i":
-                pad_value = int_pad_value
-                tensor_type = torch.int64
-            else:
-                pad_value = float_pad_value
-                tensor_type = torch.float32
+            if data_name != 'hotword_indxs':
+                if data[0][data_name].dtype.kind == "i":
+                    pad_value = int_pad_value
+                    tensor_type = torch.int64
+                else:
+                    pad_value = float_pad_value
+                    tensor_type = torch.float32
+            else:
+                # hotword index rows use -1 as the "no hotword" sentinel, so
+                # shorter rows must also be padded with -1; padding with 0
+                # would look like a real span to the slicing code below
+                pad_value = -1
+                tensor_type = torch.float32
 
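+            # convert each sample's field to a tensor; per-sample lengths are stored alongside the padded batch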
             tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
             tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
@@ -31,4 +39,51 @@
             batch[data_name] = tensor_pad
             batch[data_name + "_lengths"] = tensor_lengths
 
+    # DHA, EAHC NOT INCLUDED
+    if "hotword_indxs" in batch:
+        # when hotword indices are present in the batch, use them to
+        # slice the hotword token spans out of the padded text
+        hotword_list = []
+        hotword_lengths = []
+        text = batch['text']
+        text_lengths = batch['text_lengths']
+        hotword_indxs = batch['hotword_indxs']
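+        # an indxs length of 1 marks an utterance without hotwords; every other utterance contributes one (start, end) pair per hotword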
+        num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
+        B, t1 = text.shape
+        t1 += 1  # TODO: make this a parameter equal to predictor_bias
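+        # one attention column per hotword plus a final "no hotword" column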
+        ideal_attn = torch.zeros(B, t1, num_hw+1)
+        nth_hw = 0
+        for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
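+            # every position starts out attending to the trailing "no hotword" column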
+            ideal_attn[b][:,-1] = 1
+            if hotword_indx[0] != -1:
+                start, end = int(hotword_indx[0]), int(hotword_indx[1])
+                hotword = one_text[start: end+1]
+                hotword_list.append(hotword)
+                hotword_lengths.append(end-start+1)
+                ideal_attn[b][start:end+1, nth_hw] = 1
+                ideal_attn[b][start:end+1, -1] = 0
+                nth_hw += 1
+                if len(hotword_indx) == 4 and hotword_indx[2] != -1:
+                    # the second hotword, if present
+                    start, end = int(hotword_indx[2]), int(hotword_indx[3])
+                    hotword_list.append(one_text[start: end+1])
+                    hotword_lengths.append(end-start+1)
+                    ideal_attn[b][start:end+1, nth_hw] = 1  # the second hotword gets its own column
+                    ideal_attn[b][start:end+1, -1] = 0
+                    nth_hw += 1
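+        # dummy one-token hotword backing the trailing "no hotword" column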
+        hotword_list.append(torch.tensor([1]))
+        hotword_lengths.append(1)
+        hotword_pad = pad_sequence(hotword_list,
+                                   batch_first=True,
+                                   padding_value=0)
+        batch["hotword_pad"] = hotword_pad
+        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
+        batch['ideal_attn'] = ideal_attn
+        del batch['hotword_indxs']
+        del batch['hotword_indxs_lengths']
+
     return keys, batch
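
Note: below is a minimal standalone sketch of what the new hotword branch
produces, assuming each hotword_indxs row is [start, end] or
[start, end, start2, end2] with -1 as the "no hotword" sentinel; the toy
tensors are illustrative only, not part of the FunASR pipeline.

import torch
from torch.nn.utils.rnn import pad_sequence

# toy batch: two utterances, text already padded to a common length
text = torch.tensor([[5, 6, 7, 8, 0],
                     [9, 3, 4, 0, 0]], dtype=torch.int64)
# utterance 0 has one hotword covering tokens 1..2; utterance 1 has none
hotword_indxs = torch.tensor([[1.0, 2.0], [-1.0, -1.0]])

hotword_list, hotword_lengths = [], []
num_hw = 1                              # one (start, end) pair in the batch
B, t1 = text.shape
t1 += 1                                 # extra slot for the predictor bias
ideal_attn = torch.zeros(B, t1, num_hw + 1)
nth_hw = 0
for b, (indx, one_text) in enumerate(zip(hotword_indxs, text)):
    ideal_attn[b][:, -1] = 1            # default: the "no hotword" column
    if indx[0] != -1:
        start, end = int(indx[0]), int(indx[1])
        hotword_list.append(one_text[start:end + 1])
        hotword_lengths.append(end - start + 1)
        ideal_attn[b][start:end + 1, nth_hw] = 1
        ideal_attn[b][start:end + 1, -1] = 0
        nth_hw += 1
hotword_list.append(torch.tensor([1]))  # dummy "no hotword" entry
hotword_lengths.append(1)
hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
print(hotword_pad)    # tensor([[6, 7], [1, 0]])
print(ideal_attn[0])  # rows 1..2 attend to column 0, all others to column 1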

--
Gitblit v1.9.1