From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch

---
 funasr/frontends/utils/complex_utils.py |   29 +++++++++--------------------
 1 files changed, 9 insertions(+), 20 deletions(-)

diff --git a/funasr/frontends/utils/complex_utils.py b/funasr/frontends/utils/complex_utils.py
index 5d313c6..06c6936 100644
--- a/funasr/frontends/utils/complex_utils.py
+++ b/funasr/frontends/utils/complex_utils.py
@@ -1,16 +1,17 @@
 """Beamformer module."""
+
 from distutils.version import LooseVersion
 from typing import Sequence
 from typing import Tuple
 from typing import Union
 
 import torch
+
 try:
     from torch_complex import functional as FC
     from torch_complex.tensor import ComplexTensor
 except:
     print("Please install torch_complex firstly")
-
 
 
 EPS = torch.finfo(torch.double).eps
@@ -27,15 +28,11 @@
     elif is_torch_complex_tensor(ref):
         return torch.complex(*real_imag)
     else:
-        raise ValueError(
-            "Please update your PyTorch version to 1.9+ for complex support."
-        )
+        raise ValueError("Please update your PyTorch version to 1.9+ for complex support.")
 
 
 def is_torch_complex_tensor(c):
-    return (
-        not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
-    )
+    return not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
 
 
 def is_complex(c):
@@ -59,8 +56,7 @@
 def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
     if not isinstance(seq, (list, tuple)):
         raise TypeError(
-            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
-            "not Tensor"
+            "cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
         )
     if isinstance(seq[0], ComplexTensor):
         return FC.cat(seq, *args, **kwargs)
@@ -68,17 +64,13 @@
         return torch.cat(seq, *args, **kwargs)
 
 
-def complex_norm(
-    c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
-) -> torch.Tensor:
+def complex_norm(c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False) -> torch.Tensor:
     if not is_complex(c):
         raise TypeError("Input is not a complex tensor.")
     if is_torch_complex_tensor(c):
         return torch.norm(c, dim=dim, keepdim=keepdim)
     else:
-        return torch.sqrt(
-            (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
-        )
+        return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
 
 
 def einsum(equation, *operands):
@@ -116,9 +108,7 @@
         return torch.einsum(equation, a, b)
 
 
-def inverse(
-    c: Union[torch.Tensor, ComplexTensor]
-) -> Union[torch.Tensor, ComplexTensor]:
+def inverse(c: Union[torch.Tensor, ComplexTensor]) -> Union[torch.Tensor, ComplexTensor]:
     if isinstance(c, ComplexTensor):
         return c.inverse2()
     else:
@@ -186,8 +176,7 @@
 def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
     if not isinstance(seq, (list, tuple)):
         raise TypeError(
-            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
-            "not Tensor"
+            "stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
         )
     if isinstance(seq[0], ComplexTensor):
         return FC.stack(seq, *args, **kwargs)

--
Gitblit v1.9.1