| | |
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output ``torch.nn.Sequential`` with optional layer drop.

    Each submodule is called with the unpacked outputs of the previous one
    (``args = m(*args)``). Every submodule must therefore return an iterable,
    normally a tuple, that can be unpacked into the next submodule's
    positional arguments.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            *args: Submodules to run in order, as accepted by
                ``torch.nn.Sequential``.
            layer_drop_rate (float): Probability of dropping out each fn
                (layer) while in training mode. It has no effect in eval mode.

        """
        super().__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Run the submodules in order, randomly skipping some while training.

        Args:
            *args: Positional inputs to the first submodule.

        Returns:
            The output of the last submodule that ran. If every submodule is
            skipped, the input ``args`` tuple is returned unchanged.

        """
        # Make one uniform [0, 1) draw per layer. A layer runs when its draw is
        # >= layer_drop_rate, so it is skipped with probability
        # layer_drop_rate. In eval mode every layer runs regardless of the
        # draw.
        drop_probs = torch.empty(len(self)).uniform_()
        for idx, module in enumerate(self):
            if not self.training or drop_probs[idx] >= self.layer_drop_rate:
                args = module(*args)
        return args
| | | |
| | | |
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module. It is called as ``fn(n)``
            for each layer index ``n`` in ``range(N)``.
        layer_drop_rate (float): Probability of dropping out each fn (layer),
            forwarded to ``MultiSequential``.

    Returns:
        MultiSequential: Repeated model instance.

    """
    return MultiSequential(
        *[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate
    )
| | | |
| | | |
| | | class MultiBlocks(torch.nn.Module): |