From 167cea2074a9ab2b697fc3b43ed63babe276217f Mon Sep 17 00:00:00 2001
From: rebear077 <51772538+rebear077@users.noreply.github.com>
Date: 星期三, 14 八月 2024 16:08:08 +0800
Subject: [PATCH] Go_ws_client客户端V1 (#2011)

---
 funasr/frontends/utils/stft.py |   25 ++++++++-----------------
 1 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/funasr/frontends/utils/stft.py b/funasr/frontends/utils/stft.py
index 00d9ec5..381f1e9 100644
--- a/funasr/frontends/utils/stft.py
+++ b/funasr/frontends/utils/stft.py
@@ -86,13 +86,12 @@
         # or (Batch, Channel, Freq, Frames, 2=real_imag)
         if self.window is not None:
             if self.window.lower() == "povey":
-                window = torch.hann_window(self.win_length, periodic=False,
-                                           device=input.device, dtype=input.dtype).pow(0.85)
+                window = torch.hann_window(
+                    self.win_length, periodic=False, device=input.device, dtype=input.dtype
+                ).pow(0.85)
             else:
                 window_func = getattr(torch, f"{self.window}_window")
-                window = window_func(
-                    self.win_length, dtype=input.dtype, device=input.device
-                )
+                window = window_func(self.win_length, dtype=input.dtype, device=input.device)
         else:
             window = None
 
@@ -135,9 +134,7 @@
                     [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
                 ).numpy()
             else:
-                win_length = (
-                    self.win_length if self.win_length is not None else self.n_fft
-                )
+                win_length = self.win_length if self.win_length is not None else self.n_fft
                 stft_kwargs["window"] = torch.ones(win_length)
 
             output = []
@@ -160,9 +157,7 @@
         if multi_channel:
             # output: (Batch * Channel, Frames, Freq, 2=real_imag)
             # -> (Batch, Frame, Channel, Freq, 2=real_imag)
-            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
-                1, 2
-            )
+            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
 
         if ilens is not None:
             if self.center:
@@ -194,14 +189,10 @@
             try:
                 import torchaudio
             except ImportError:
-                raise ImportError(
-                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
-                )
+                raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
 
             if not hasattr(torchaudio.functional, "istft"):
-                raise ImportError(
-                    "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
-                )
+                raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
             istft = torchaudio.functional.istft
 
         if self.window is not None:

--
Gitblit v1.9.1