| | |
| | | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.modules.nets_utils import pad_list |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | from funasr.modules.nets_utils import pad_list, pad_list_all_dim |
| | | |
| | | |
| | | class CommonCollateFn: |
| | |
| | | output = (uttids, output) |
| | | return output |
| | | |
| | | |
class DiarCollateFn:
    """Functor class of diar_collate_fn().

    The padding configuration is captured once at construction time so the
    instance can later be called with only the batch of samples.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        """Store the collate settings.

        Args:
            float_pad_value: Forwarded to diar_collate_fn() as
                ``float_pad_value``.
            int_pad_value: Forwarded to diar_collate_fn() as
                ``int_pad_value``.
            not_sequence: Keys forwarded to diar_collate_fn() as
                ``not_sequence``; kept internally as a set.
            max_sample_size: Stored on the instance only; ``__call__`` does
                not forward it to diar_collate_fn().
        """
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        """Return a string showing the class and both pad values."""
        # The int_pad_value field must report self.int_pad_value; it
        # previously echoed the float pad value a second time.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        """Collate ``data`` by delegating to diar_collate_fn().

        Args:
            data: Collection of ``(utterance_id, {key: ndarray})`` pairs.

        Returns:
            Whatever diar_collate_fn() returns for the stored settings.
        """
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
| | | |
| | | |
def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Pad per-sample ndarrays into batch tensors, one tensor per key.

    Args:
        data: Collection of ``(utterance_id, {key: ndarray})`` pairs. Every
            sample must expose the same keys, and no key may end with
            ``"_lengths"``.
        float_pad_value: Pad value used when the first sample's array for a
            key does not have dtype kind ``"i"``.
        int_pad_value: Pad value used when the first sample's array for a
            key has dtype kind ``"i"``.
        not_sequence: Keys for which no ``<key>_lengths`` entry is produced.

    Returns:
        ``(uttids, batch)`` where ``batch[key]`` is the output of
        ``pad_list_all_dim`` and, for keys outside ``not_sequence``,
        ``batch[key + "_lengths"]`` holds each sample's ``shape[0]`` as a
        ``torch.long`` tensor.
    """
    assert check_argument_types()

    # Split the (id, sample) pairs into two parallel lists.
    uttids = []
    samples = []
    for utt_id, sample in data:
        uttids.append(utt_id)
        samples.append(sample)

    # The first sample is the reference for key set and dtype decisions.
    reference = samples[0]
    reference_keys = set(reference)
    assert all(
        reference_keys == set(sample) for sample in samples
    ), "dict-keys mismatching"
    assert all(
        not name.endswith("_lengths") for name in reference
    ), f"*_lengths is reserved: {list(reference)}"

    batch = {}
    for key, first_array in reference.items():
        is_signed_int = first_array.dtype.kind == "i"
        pad_value = int_pad_value if is_signed_int else float_pad_value

        tensors = [torch.from_numpy(sample[key]) for sample in samples]
        batch[key] = pad_list_all_dim(tensors, pad_value)

        if key in not_sequence:
            continue
        first_dims = [sample[key].shape[0] for sample in samples]
        batch[key + "_lengths"] = torch.tensor(first_dims, dtype=torch.long)

    result = (uttids, batch)
    assert check_return_type(result)
    return result
| | | |
| | | |
| | | def crop_to_max_size(feature, target_size): |
| | | size = len(feature) |
| | | diff = size - target_size |