游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/layers/stft.py
@@ -5,7 +5,6 @@
import torch
from torch_complex.tensor import ComplexTensor
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
@@ -30,7 +29,6 @@
        normalized: bool = False,
        onesided: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
@@ -42,7 +40,8 @@
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
            if window.lower() != "povey":
                raise ValueError(f"{window} window is not implemented")
        self.window = window
    def extra_repr(self):
@@ -83,10 +82,14 @@
        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
            if self.window.lower() == "povey":
                window = torch.hann_window(self.win_length, periodic=False,
                                           device=input.device, dtype=input.dtype).pow(0.85)
            else:
                window_func = getattr(torch, f"{self.window}_window")
                window = window_func(
                    self.win_length, dtype=input.dtype, device=input.device
                )
        else:
            window = None