kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/frontends/utils/complex_utils.py
@@ -1,16 +1,17 @@
"""Beamformer module."""
from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union
import torch
try:
    from torch_complex import functional as FC
    from torch_complex.tensor import ComplexTensor
except:
    print("Please install torch_complex firstly")
EPS = torch.finfo(torch.double).eps
@@ -27,15 +28,11 @@
    elif is_torch_complex_tensor(ref):
        return torch.complex(*real_imag)
    else:
        raise ValueError(
            "Please update your PyTorch version to 1.9+ for complex support."
        )
        raise ValueError("Please update your PyTorch version to 1.9+ for complex support.")
def is_torch_complex_tensor(c):
    return (
        not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
    )
    return not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
def is_complex(c):
@@ -59,8 +56,7 @@
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.cat(seq, *args, **kwargs)
@@ -68,17 +64,13 @@
        return torch.cat(seq, *args, **kwargs)
def complex_norm(
    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
) -> torch.Tensor:
def complex_norm(c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False) -> torch.Tensor:
    if not is_complex(c):
        raise TypeError("Input is not a complex tensor.")
    if is_torch_complex_tensor(c):
        return torch.norm(c, dim=dim, keepdim=keepdim)
    else:
        return torch.sqrt(
            (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
        )
        return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
def einsum(equation, *operands):
@@ -116,9 +108,7 @@
        return torch.einsum(equation, a, b)
def inverse(
    c: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
def inverse(c: Union[torch.Tensor, ComplexTensor]) -> Union[torch.Tensor, ComplexTensor]:
    if isinstance(c, ComplexTensor):
        return c.inverse2()
    else:
@@ -186,8 +176,7 @@
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
    if not isinstance(seq, (list, tuple)):
        raise TypeError(
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
            "not Tensor"
            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
        )
    if isinstance(seq[0], ComplexTensor):
        return FC.stack(seq, *args, **kwargs)