| | |
| | | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) |
| | | return ys_mask.unsqueeze(-2) & m |
| | | |
| | | |
def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
    """Create mask for decoder self-attention.

    Builds a ``(size, size)`` mask that is all ones (``True`` for bool
    dtype), except for the block ``[0 : vad_pos - 1, vad_pos : size]``,
    which is set to zero (``False``). In other words, the rows before
    ``vad_pos - 1`` are cut off from the columns at and after ``vad_pos``.
    Row ``vad_pos - 1`` and all later rows are left unrestricted.

    NOTE(review): ``vad_pos`` is presumably a boundary position produced by
    voice activity detection; confirm against callers.

    :param int size: side length of the square mask
    :param int vad_pos: boundary position; if it is ``<= 0`` or ``>= size``,
        no region is masked and an all-ones mask is returned
    :param device: device for the result (e.g. ``"cpu"`` or a
        ``torch.device``)
    :param torch.dtype dtype: result dtype (default ``torch.bool``)
    :return: mask tensor of shape ``(size, size)``
    :rtype: torch.Tensor
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    # Out-of-range boundary: nothing to mask.
    if vad_pos <= 0 or vad_pos >= size:
        return ret
    # Fill the upper-right block in place. This avoids allocating a separate
    # zeros tensor; a scalar 0 becomes False for bool and 0 for numeric dtypes.
    ret[0 : vad_pos - 1, vad_pos:] = 0
    return ret