shixian.shi
2023-05-04 8cac1c97d92a0cbfaf9862930404ede92150beeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
 
 
def _hotword_spans(hotword_indx):
    """Decode one sample's hotword index list into inclusive (start, end) spans.

    The list is read as ``[start, end]`` or ``[start1, end1, start2, end2]``;
    a ``-1`` in a start slot means "no hotword in this slot".
    """
    idx = [int(i) for i in hotword_indx]
    spans = []
    if len(idx) >= 2 and idx[0] != -1:
        spans.append((idx[0], idx[1]))
        # A second hotword is only considered when the first one exists.
        if len(idx) >= 4 and idx[2] != -1:
            spans.append((idx[2], idx[3]))
    return spans


def padding(data, float_pad_value=0.0, int_pad_value=-1):
    """Collate a list of per-sample dicts into one padded batch.

    Args:
        data: Non-empty list of dicts. Every dict must contain ``"key"`` and
            at least one of ``"speech"`` / ``"text"``. Apart from ``"key"``,
            ``"sampling_rate"`` and ``"hotword_indxs"``, each value must
            expose ``.dtype.kind`` (e.g. a NumPy array) and support ``len()``.
        float_pad_value: Fill value for fields that are not of integer kind.
        int_pad_value: Fill value for fields whose dtype kind is ``"i"``.

    Returns:
        ``(keys, batch)`` where ``keys`` lists every sample's ``"key"`` and
        ``batch`` maps:

        * ``"key"`` / ``"sampling_rate"`` -> the first sample's value, as is;
        * ``name`` -> the padded tensor (int64 or float32, batch first) and
          ``name + "_lengths"`` -> int32 tensor of the unpadded lengths;
        * when samples carry ``"hotword_indxs"`` (``"text"`` is then
          required): ``"hotword_pad"``, ``"hotword_lengths"`` and
          ``"ideal_attn"``. The raw indices are not kept in ``batch``.
    """
    assert isinstance(data, list)
    assert "key" in data[0]
    assert "speech" in data[0] or "text" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    has_hotwords = False
    for data_name in data[0]:
        if data_name == "hotword_indxs":
            # Handled after the loop: the index lists are consumed per sample
            # to slice `text`, and are never returned in the batch.
            has_hotwords = True
            continue
        if data_name in ("key", "sampling_rate"):
            batch[data_name] = data[0][data_name]
            continue

        if data[0][data_name].dtype.kind == "i":
            pad_value = int_pad_value
            tensor_type = torch.int64
        else:
            pad_value = float_pad_value
            tensor_type = torch.float32

        tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
        tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
        batch[data_name] = pad_sequence(tensor_list,
                                        batch_first=True,
                                        padding_value=pad_value)
        batch[data_name + "_lengths"] = tensor_lengths

    # DHA, EAHC NOT INCLUDED
    if has_hotwords:
        text = batch["text"]
        # Gather all spans first so that ideal_attn gets exactly one column
        # per extracted hotword (plus the trailing "no hotword" column).
        spans_per_sample = [_hotword_spans(d["hotword_indxs"]) for d in data]
        num_hw = sum(len(spans) for spans in spans_per_sample)

        B, t1 = text.shape
        t1 += 1  # TODO: as parameter which is same as predictor_bias
        ideal_attn = torch.zeros(B, t1, num_hw + 1)
        # By default every position points at the last ("no hotword") column.
        ideal_attn[:, :, -1] = 1

        hotword_list = []
        hotword_lengths = []
        for b, spans in enumerate(spans_per_sample):
            for start, end in spans:
                # Column index == row index of this hotword in hotword_pad.
                nth_hw = len(hotword_list)
                hotword_list.append(text[b][start: end + 1])
                hotword_lengths.append(end - start + 1)
                ideal_attn[b, start:end + 1, nth_hw] = 1
                ideal_attn[b, start:end + 1, -1] = 0

        # Trailing placeholder entry matching the last column of ideal_attn.
        hotword_list.append(text.new_tensor([1]))
        hotword_lengths.append(1)
        batch["hotword_pad"] = pad_sequence(hotword_list,
                                            batch_first=True,
                                            padding_value=0)
        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
        batch["ideal_attn"] = ideal_attn

    return keys, batch