游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/datasets/large_datasets/collate_fn.py
@@ -13,11 +13,11 @@
    """Functor class of common_collate_fn()"""
    def __init__(
            self,
            float_pad_value: Union[float, int] = 0.0,
            int_pad_value: int = -32768,
            not_sequence: Collection[str] = (),
            max_sample_size=None
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
@@ -31,7 +31,7 @@
        )
    def __call__(
            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return common_collate_fn(
            data,
@@ -42,13 +42,12 @@
def common_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.
    """
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    uttids = [u for u, _ in data]
    data = [d for _, d in data]
@@ -81,11 +80,11 @@
    """Functor class of common_collate_fn()"""
    def __init__(
            self,
            float_pad_value: Union[float, int] = 0.0,
            int_pad_value: int = -32768,
            not_sequence: Collection[str] = (),
            max_sample_size=None
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
@@ -99,7 +98,7 @@
        )
    def __call__(
            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return diar_collate_fn(
            data,
@@ -110,13 +109,12 @@
def diar_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.
    """
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    uttids = [u for u, _ in data]
    data = [d for _, d in data]
@@ -157,9 +155,9 @@
def clipping_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        max_sample_size=None,
        not_sequence: Collection[str] = (),
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    uttids = [u for u, _ in data]
@@ -193,4 +191,4 @@
            output[key + "_lengths"] = lens
    output = (uttids, output)
    return output
    return output