from typing import Optional, Tuple
import torch
import torch.nn as nn
 
 
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
 
 
@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
 
    URL: https://github.com/openai/whisper
    """
 
    def __init__(
        self,
        fs: int = 16000,
        whisper_model: Optional[str] = None,
        do_pad_trim: bool = True,
        n_mels: int = 80,
        permute: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert fs == 16000
        self.fs = fs
        import whisper
        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
 
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.pad_samples = N_SAMPLES
        self.frame_shift = int(self.hop_length / self.fs * 1000)
        self.lfr_n = 1
        self.n_mels = n_mels
        if whisper_model == "large-v3" or whisper_model == "large":
            self.n_mels = 128
 
        filters_path = kwargs.get("filters_path", None)
        self.filters_path = filters_path
        if filters_path is not None:
            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
 
            self.mel_filters = mel_filters
        else:
            self.mel_filters = whisper.audio.mel_filters
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        self.permute = permute
 
        # assert whisper_model in whisper.available_models()
 
    def output_size(self) -> int:
        return self.n_mels
 
    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
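        """Compute Whisper-style log-mel features for a batch of waveforms.

        Follows whisper.audio.log_mel_spectrogram: Hann-window STFT, mel filterbank
        projection, clamped log10, an 8-unit floor below the per-utterance maximum,
        and the (x + 4) / 4 rescaling used by Whisper.
        """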
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
 
        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
        if self.filters_path is not None:
            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
        else:
            filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes
 
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
 
        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None
 
        log_spec = torch.maximum(
            log_spec,
            log_spec.view(audio.size(0), -1).max(dim=-1)[0][:, None, None] - 8.0,
        )
        log_spec = (log_spec + 4.0) / 4.0
 
        return log_spec, olens
 
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
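        """Convert a padded waveform batch into Whisper log-mel features.

        Each utterance is optionally padded/trimmed to Whisper's 30-second window
        before feature extraction; the returned lengths are frame counts derived
        from the original (unpadded) sample lengths.
        """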
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        input = input.to(torch.float32)
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            # use the i-th utterance's own sample count, not the first item's
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[i])
            feats.append(feat[0])
            feats_lens.append(feat_len)
        feats_lens = torch.as_tensor(feats_lens)
 
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        if self.permute:
            feats_pad = feats_pad.permute(0, 2, 1)
        return feats_pad, feats_lens
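

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: construct the frontend
    # directly and run it on random 16 kHz waveforms. It assumes the openai-whisper
    # package is installed; the model name and durations below are illustrative.
    frontend = WhisperFrontend(fs=16000, whisper_model="base", do_pad_trim=True)

    waveforms = torch.randn(2, 16000 * 3)           # two 3-second utterances
    lengths = torch.tensor([16000 * 3, 16000 * 2])  # valid samples per utterance

    feats, feats_lens = frontend(waveforms, lengths)
    # With do_pad_trim=True every item is padded/trimmed to 30 s, so the features have
    # shape (batch, n_mels, 3000); feats_lens holds frame counts for the unpadded audio.
    print(feats.shape, feats_lens)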