onlybetheone
2022-12-22 96c56e556e43fbe663a86f1f06d1b5b20a92e053
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
 
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
 
from funasr.modules.nets_utils import pad_list
 
 
class CommonCollateFn:
    """Functor class of common_collate_fn()"""
 
    def __init__(
            self,
            float_pad_value: Union[float, int] = 0.0,
            int_pad_value: int = -32768,
            not_sequence: Collection[str] = (),
            max_sample_size=None
    ):
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size
 
    def __repr__(self):
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.float_pad_value})"
        )
 
    def __call__(
            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
 
 
def common_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor.
    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]
 
    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"
 
    output = {}
    for key in data[0]:
        if data[0][key].dtype.kind == "i":
            pad_value = int_pad_value
        else:
            pad_value = float_pad_value
 
        array_list = [d[key] for d in data]
        tensor_list = [torch.from_numpy(a) for a in array_list]
        tensor = pad_list(tensor_list, pad_value)
        output[key] = tensor
 
        if key not in not_sequence:
            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
            output[key + "_lengths"] = lens
 
    output = (uttids, output)
    assert check_return_type(output)
    return output