zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/datasets/large_datasets/utils/padding.py
@@ -16,7 +16,7 @@
         if data_name == "key" or data_name == "sampling_rate":
             continue
         else:
-            if data_name != 'hotword_indxs':
+            if data_name != "hotword_indxs":
                 if data[0][data_name].dtype.kind == "i":
                     pad_value = int_pad_value
                     tensor_type = torch.int64
@@ -26,9 +26,7 @@
             tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
             tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
-            tensor_pad = pad_sequence(tensor_list,
-                                      batch_first=True,
-                                      padding_value=pad_value)
+            tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)
             batch[data_name] = tensor_pad
             batch[data_name + "_lengths"] = tensor_lengths
@@ -38,14 +36,16 @@
         # use it to slice hotwords out
         hotword_list = []
         hotword_lengths = []
-        text = batch['text']
-        text_lengths = batch['text_lengths']
-        hotword_indxs = batch['hotword_indxs']
+        text = batch["text"]
+        text_lengths = batch["text_lengths"]
+        hotword_indxs = batch["hotword_indxs"]
         dha_pad = torch.ones_like(text) * -1
         _, t1 = text.shape
         t1 += 1  # TODO: as parameter which is same as predictor_bias
         nth_hw = 0
-        for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
+        for b, (hotword_indx, one_text, length) in enumerate(
+            zip(hotword_indxs, text, text_lengths)
+        ):
             dha_pad[b][:length] = 8405
             if hotword_indx[0] != -1:
                 start, end = int(hotword_indx[0]), int(hotword_indx[1])
@@ -63,12 +63,10 @@
                     nth_hw += 1
         hotword_list.append(torch.tensor([1]))
         hotword_lengths.append(1)
-        hotword_pad = pad_sequence(hotword_list,
-                                batch_first=True,
-                                padding_value=0)
+        hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
         batch["hotword_pad"] = hotword_pad
         batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
-        batch['dha_pad'] = dha_pad
-        del batch['hotword_indxs']
-        del batch['hotword_indxs_lengths']
+        batch["dha_pad"] = dha_pad
+        del batch["hotword_indxs"]
+        del batch["hotword_indxs_lengths"]
     return keys, batch
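
For reference, a minimal standalone sketch of the pad_sequence padding pattern this file relies on (the example tensors below are illustrative, not taken from FunASR):

import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length 1-D tensors, as collected per data_name in the collator.
tensor_list = [torch.tensor([3, 1, 4]), torch.tensor([1, 5]), torch.tensor([9, 2, 6, 5])]
tensor_lengths = torch.tensor([len(t) for t in tensor_list], dtype=torch.int32)

# batch_first=True yields shape (batch, max_len); shorter rows are
# right-padded with padding_value, mirroring tensor_pad / hotword_pad above.
tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=0)

print(tensor_pad)
# tensor([[3, 1, 4, 0],
#         [1, 5, 0, 0],
#         [9, 2, 6, 5]])
print(tensor_lengths)
# tensor([3, 2, 4], dtype=torch.int32)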