From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch
---
funasr/frontends/utils/complex_utils.py | 29 +++++++++--------------------
1 files changed, 9 insertions(+), 20 deletions(-)
diff --git a/funasr/frontends/utils/complex_utils.py b/funasr/frontends/utils/complex_utils.py
index 5d313c6..06c6936 100644
--- a/funasr/frontends/utils/complex_utils.py
+++ b/funasr/frontends/utils/complex_utils.py
@@ -1,16 +1,17 @@
"""Beamformer module."""
+
from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union
import torch
+
try:
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
-
EPS = torch.finfo(torch.double).eps
@@ -27,15 +28,11 @@
elif is_torch_complex_tensor(ref):
return torch.complex(*real_imag)
else:
- raise ValueError(
- "Please update your PyTorch version to 1.9+ for complex support."
- )
+ raise ValueError("Please update your PyTorch version to 1.9+ for complex support.")
def is_torch_complex_tensor(c):
- return (
- not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
- )
+ return not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
def is_complex(c):
@@ -59,8 +56,7 @@
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
- "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
- "not Tensor"
+ "cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.cat(seq, *args, **kwargs)
@@ -68,17 +64,13 @@
return torch.cat(seq, *args, **kwargs)
-def complex_norm(
- c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
-) -> torch.Tensor:
+def complex_norm(c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False) -> torch.Tensor:
if not is_complex(c):
raise TypeError("Input is not a complex tensor.")
if is_torch_complex_tensor(c):
return torch.norm(c, dim=dim, keepdim=keepdim)
else:
- return torch.sqrt(
- (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
- )
+ return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
def einsum(equation, *operands):
@@ -116,9 +108,7 @@
return torch.einsum(equation, a, b)
-def inverse(
- c: Union[torch.Tensor, ComplexTensor]
-) -> Union[torch.Tensor, ComplexTensor]:
+def inverse(c: Union[torch.Tensor, ComplexTensor]) -> Union[torch.Tensor, ComplexTensor]:
if isinstance(c, ComplexTensor):
return c.inverse2()
else:
@@ -186,8 +176,7 @@
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
- "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
- "not Tensor"
+ "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.stack(seq, *args, **kwargs)
--
Gitblit v1.9.1