From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0

---
 funasr/layers/stft.py |   17 ++++++++++-------
 1 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index 21beaae..dfb6919 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -5,7 +5,6 @@
 
 import torch
 from torch_complex.tensor import ComplexTensor
-from typeguard import check_argument_types
 
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.layers.complex_utils import is_complex
@@ -30,7 +29,6 @@
         normalized: bool = False,
         onesided: bool = True,
     ):
-        assert check_argument_types()
         super().__init__()
         self.n_fft = n_fft
         if win_length is None:
@@ -42,7 +40,8 @@
         self.normalized = normalized
         self.onesided = onesided
         if window is not None and not hasattr(torch, f"{window}_window"):
-            raise ValueError(f"{window} window is not implemented")
+            if window.lower() != "povey":
+                raise ValueError(f"{window} window is not implemented")
         self.window = window
 
     def extra_repr(self):
@@ -83,10 +82,14 @@
         # output: (Batch, Freq, Frames, 2=real_imag)
         # or (Batch, Channel, Freq, Frames, 2=real_imag)
         if self.window is not None:
-            window_func = getattr(torch, f"{self.window}_window")
-            window = window_func(
-                self.win_length, dtype=input.dtype, device=input.device
-            )
+            if self.window.lower() == "povey":
+                window = torch.hann_window(self.win_length, periodic=False,
+                                           device=input.device, dtype=input.dtype).pow(0.85)
+            else:
+                window_func = getattr(torch, f"{self.window}_window")
+                window = window_func(
+                    self.win_length, dtype=input.dtype, device=input.device
+                )
         else:
             window = None
 

--
Gitblit v1.9.1