"""Beamformer module: complex-tensor utilities.

Provides thin dispatch helpers that work with both ``torch_complex``'s
``ComplexTensor`` and (on torch>=1.9) PyTorch's native complex tensors.
"""

from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union

import torch

try:
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except ImportError:
    # Narrowed from a bare `except:` so that unrelated errors (e.g. a broken
    # torch install) are not silently swallowed. The module stays importable
    # without torch_complex; only the ComplexTensor code paths will fail.
    print("Please install torch_complex firstly")

# True when PyTorch ships native complex-tensor support (torch.complex,
# torch.is_complex). Consumed by is_torch_complex_tensor() below.
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")

# Smallest positive double increment; added under sqrt to avoid d/dx at 0.
EPS = torch.finfo(torch.double).eps
def new_complex_like(
    ref: Union[torch.Tensor, ComplexTensor],
    real_imag: Tuple[torch.Tensor, torch.Tensor],
):
    """Build a complex tensor of the same complex type as ``ref``.

    NOTE(review): the original ``def`` header and ``if`` branch were destroyed
    by file corruption; reconstructed from the surviving ``elif``/``else``
    logic — confirm against upstream.

    Args:
        ref: reference tensor whose complex flavor decides the output type.
        real_imag: (real, imag) component tensors.

    Returns:
        ``ComplexTensor`` or native complex ``torch.Tensor``.

    Raises:
        ValueError: if ``ref`` is neither kind of complex tensor.
    """
    if isinstance(ref, ComplexTensor):
        return ComplexTensor(*real_imag)
    elif is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )
def is_torch_complex_tensor(c):
    """Return True iff ``c`` is a native complex ``torch.Tensor`` (torch>=1.9).

    A ``ComplexTensor`` is explicitly excluded so the two complex flavors
    are distinguishable by the dispatch helpers in this module.
    """
    # The original had a second, unreachable duplicate of this return.
    return (
        not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
    )
def is_complex(c):
    """Return True iff ``c`` is complex-valued of either supported flavor.

    NOTE(review): the body was destroyed by file corruption; reconstructed as
    the union of the two complex-type predicates used throughout this module
    (see complex_norm's is_complex/is_torch_complex_tensor usage) — confirm
    against upstream.
    """
    return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    """Concatenate a sequence of tensors, complex-flavor aware.

    Dispatches to ``FC.cat`` when the first element is a ``ComplexTensor``,
    otherwise to ``torch.cat``.

    Raises:
        TypeError: if ``seq`` is not a list/tuple (mirrors torch.cat's error).
    """
    if not isinstance(seq, (list, tuple)):
        # The corrupted original repeated the message twice; keep one copy.
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)

    return torch.cat(seq, *args, **kwargs)
def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
    """Return the norm of a complex tensor along ``dim``.

    Args:
        c: complex input (``ComplexTensor`` or native complex tensor).
        dim: dimension to reduce over.
        keepdim: whether the reduced dimension is retained with size 1.

    Raises:
        TypeError: if ``c`` is not complex-valued.
    """
    # The corrupted original contained a duplicated collapsed `def` line and
    # a duplicated unreachable return; both removed.
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)
    else:
        # EPS guards sqrt's gradient against an exact zero argument.
        return torch.sqrt(
            (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
        )
def einsum(equation, *operands):
    """Einsum dispatching between ``FC.einsum`` and ``torch.einsum``.

    NOTE(review): the branch selecting FC.einsum and the operand unpacking
    were destroyed by file corruption (only the torch.einsum fallback
    survived, referencing unbound names); reconstructed so that any
    ComplexTensor operand routes through FC — confirm against upstream.
    """
    # Allow einsum(eq, [a, b]) in addition to einsum(eq, a, b).
    if len(operands) == 1 and isinstance(operands[0], (tuple, list)):
        operands = operands[0]
    if any(isinstance(op, ComplexTensor) for op in operands):
        return FC.einsum(equation, *operands)
    return torch.einsum(equation, *operands)
def inverse(
    c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
    """Matrix inverse supporting both complex-tensor flavors.

    NOTE(review): the ``else`` branch body was lost to file corruption;
    restored as the plain ``c.inverse()`` fallback mirroring the
    ComplexTensor branch — confirm against upstream.
    """
    if isinstance(c, ComplexTensor):
        # inverse2() is torch_complex's numerically robust complex inverse.
        return c.inverse2()
    else:
        return c.inverse()
| | | def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs): |
| | | if not isinstance(seq, (list, tuple)): |
| | | raise TypeError( |
| | | "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " |
| | | "not Tensor" |
| | | "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor" |
| | | ) |
| | | if isinstance(seq[0], ComplexTensor): |
| | | return FC.stack(seq, *args, **kwargs) |