From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/layers/stft.py | 15 ++++++++++-----
1 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index 21beaae..376b5a3 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -42,7 +42,8 @@
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
- raise ValueError(f"{window} window is not implemented")
+ if window.lower() != "povey":
+ raise ValueError(f"{window} window is not implemented")
self.window = window
def extra_repr(self):
@@ -83,10 +84,14 @@
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if self.window is not None:
- window_func = getattr(torch, f"{self.window}_window")
- window = window_func(
- self.win_length, dtype=input.dtype, device=input.device
- )
+ if self.window.lower() == "povey":
+ window = torch.hann_window(self.win_length, periodic=False,
+ device=input.device, dtype=input.dtype).pow(0.85)
+ else:
+ window_func = getattr(torch, f"{self.window}_window")
+ window = window_func(
+ self.win_length, dtype=input.dtype, device=input.device
+ )
else:
window = None
--
Gitblit v1.9.1