游雁
2024-06-09 b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c
funasr/datasets/large_datasets/utils/clipping.py
@@ -1,7 +1,7 @@
import numpy as np
import torch
from funasr.datasets.collate_fn import crop_to_max_size
from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
def clipping(data):
@@ -25,7 +25,9 @@
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
            length_clip = min(tensor_lengths)
            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
            tensor_clip = tensor_list[0].new_zeros(
                len(tensor_list), length_clip, tensor_list[0].shape[1]
            )
            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
                diff = length - length_clip
                assert diff >= 0
@@ -35,6 +37,8 @@
                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
            batch[data_name] = tensor_clip
            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
            batch[data_name + "_lengths"] = torch.tensor(
                [tensor.shape[0] for tensor in tensor_clip], dtype=torch.long
            )
    return keys, batch