| | |
| | | else: |
| | | return mask |
| | | |
class sequence_mask(nn.Module):
    """Build a padding mask from a 1-D tensor of per-sequence lengths.

    Given ``lengths = [2, 3]`` and ``max_seq_len = 4`` the result is::

        [[1, 1, 0, 0],
         [1, 1, 1, 0]]

    NOTE(review): the constructor previously accepted ``max_seq_len`` and
    ``flip`` but silently discarded both. They are now stored as attributes
    so the configuration is at least visible; ``flip`` is still not used by
    ``forward`` — presumably it was meant to invert the mask. Confirm the
    intended semantics before wiring it in.
    """

    def __init__(self, max_seq_len=512, flip=True):
        super().__init__()
        # Fix: keep the constructor arguments instead of dropping them
        # (original bug — both parameters were ignored).
        self.max_seq_len = max_seq_len
        self.flip = flip

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Return a ``(batch, max_seq_len)`` mask of ``dtype``.

        Args:
            lengths: 1-D integer tensor of valid lengths per sequence.
            max_seq_len: mask width; defaults to ``lengths.max()``.
            dtype: output dtype (default ``torch.float32``).
            device: optional device to move the result to; when ``None`` the
                mask stays on ``lengths``' device.

        Returns:
            Tensor where entry ``[i, j]`` is 1 iff ``j < lengths[i]``.
        """
        if max_seq_len is None:
            # int(...) so torch.arange gets a plain number, not a 0-dim
            # tensor — the tensor form is fragile across torch versions.
            max_seq_len = int(lengths.max())
        # Allocate positions directly on the right device instead of
        # arange-on-CPU followed by .to() (avoids a host->device copy).
        row_vector = torch.arange(max_seq_len, device=lengths.device)
        # Broadcast positions [0, max_seq_len) against lengths[..., None].
        matrix = torch.unsqueeze(lengths, dim=-1)
        mask = (row_vector < matrix).type(dtype)
        return mask.to(device) if device is not None else mask
| | | |
| | | def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor: |
| | | if out is None: |