1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
| """Time warp module."""
| import torch
|
| from funasr.modules.nets_utils import pad_list
|
DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.nn.functional.interpolate.

    A center frame ``c`` is drawn uniformly from ``[window, T - window)`` and a
    target position ``w`` from ``[c - window + 1, c + window]``. Frames
    ``[0, c)`` are resampled to length ``w`` and frames ``[c, T)`` to length
    ``T - w``, so the total length ``T`` is preserved. The same warp is applied
    to every sample in the batch.

    Args:
        x: (Batch, Time, Freq) or (Batch, Channel, Time, Freq)
        window: time warp parameter. A non-positive value disables warping.
        mode: Interpolate mode passed to ``torch.nn.functional.interpolate``.

    Returns:
        Tensor with the same shape as ``x``. The input is returned unchanged
        when ``window <= 0`` or when ``Time <= 2 * window``.

    Note:
        When ``x.requires_grad`` is False, the warped values are written back
        into ``x`` in place (avoiding a new allocation), so the caller's tensor
        is modified.
    """
    org_size = x.size()
    if x.dim() == 3:
        # bicubic supports 4D or more dimension tensor:
        # (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    # A non-positive window would make the draw for `warped` an empty range
    # (randint(c, c)) and crash, so treat it as "no warping".
    if window <= 0 or t - window <= window:
        return x.view(*org_size)

    # Python ints: these values are used as slice bounds and output sizes.
    center = int(torch.randint(window, t - window, (1,))[0])
    warped = int(torch.randint(center - window, center + window, (1,))[0]) + 1

    # align_corners is only accepted by the linear-family modes. Passing it to
    # e.g. "nearest" or "area" makes interpolate raise, so omit it there.
    align_corners = (
        False if mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    )

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=align_corners
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=align_corners
    )

    if x.requires_grad:
        # Out-of-place so autograd can track the operation.
        x = torch.cat([left, right], dim=-2)
    else:
        # `left`/`right` are freshly allocated, so overwriting x is safe here.
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)
|
|
class TimeWarp(torch.nn.Module):
    """Module wrapper around the ``time_warp`` function.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Apply time warping to a batch.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,) per-sample lengths, or None to warp the whole
                batch as a single unit.

        Returns:
            Tuple of the warped features and ``x_lengths`` (passed through
            unchanged).
        """
        same_length = x_lengths is None or all(
            length == x_lengths[0] for length in x_lengths
        )
        if same_length:
            # One call for the whole batch; note that every sample then
            # receives the same random warp.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # FIXME(kamo): I have no idea to batchify Timewarp
        # Lengths differ: warp each sample over its first x_lengths[b] frames,
        # then re-pad the results with pad_list (pad value 0.0).
        per_sample = [
            time_warp(
                x[b][None, : x_lengths[b]],
                window=self.window,
                mode=self.mode,
            )[0]
            for b in range(x.size(0))
        ]
        return pad_list(per_sample, 0.0), x_lengths
|
|