jmwang66
2022-12-26 682204f0bb1335eb9ba3a2f0eb5605bdf42e8505
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
 
import copy
from typing import Optional, Tuple, Union
 
import humanfriendly
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.modules.frontends.frontend import Frontend
from typeguard import check_argument_types
 
 
def apply_cmvn(inputs, mvn):  # noqa
    """Normalize ``inputs`` in place with precomputed CMVN statistics.

    ``mvn`` is a 2-row array: row 0 is added to every frame and row 1 is
    multiplied into every frame (presumably negated means and inverse
    standard deviations, per the usual CMVN convention — confirm against
    the mvn file producer). Returns the normalized features as float32.
    """
    num_frames, feat_dim = inputs.shape

    # Tile each statistics row so it matches the (frames, dim) layout.
    shift = np.tile(mvn[0:1, :feat_dim], (num_frames, 1))
    scale = np.tile(mvn[1:2, :feat_dim], (num_frames, 1))

    shift_t = torch.from_numpy(shift).type(inputs.dtype).to(inputs.device)
    scale_t = torch.from_numpy(scale).type(inputs.dtype).to(inputs.device)

    # In-place: add the shift, then apply the scale.
    inputs += shift_t
    inputs *= scale_t

    return inputs.type(torch.float32)
 
 
def apply_lfr(inputs, lfr_m, lfr_n):
    """Apply low-frame-rate stacking: concatenate ``lfr_m`` consecutive
    frames into one output frame, advancing ``lfr_n`` input frames per step.

    The first frame is replicated ``(lfr_m - 1) // 2`` times as left
    context, and the last frame is replicated as needed to fill the final
    window. Returns a float32 tensor of shape (ceil(T / lfr_n), lfr_m * dim).
    """
    num_frames = inputs.shape[0]
    num_outputs = int(np.ceil(num_frames / lfr_n))

    # Prepend replicated-first-frame left context.
    left_context = (lfr_m - 1) // 2
    inputs = torch.vstack((inputs[0].repeat(left_context, 1), inputs))
    total = num_frames + left_context

    stacked = []
    for out_idx in range(num_outputs):
        start = out_idx * lfr_n
        remaining = total - start
        if remaining >= lfr_m:
            # Full window available: flatten lfr_m frames into one row.
            stacked.append(inputs[start:start + lfr_m].view(1, -1))
        else:
            # Final partial window: pad by repeating the last input frame.
            row = inputs[start:].view(-1)
            for _ in range(lfr_m - remaining):
                row = torch.hstack((row, inputs[-1]))
            stacked.append(row)

    return torch.vstack(stacked).type(torch.float32)
 
 
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Converts a waveform batch to kaldi-style fbank features, optionally
    followed by low-frame-rate (LFR) stacking and CMVN normalization.
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = 400,
        hop_length: int = 160,
        window: Optional[str] = 'hamming',
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        # fmin/fmax were annotated as plain `int` with a None default,
        # which check_argument_types() can reject at call time.
        fmin: Optional[int] = None,
        fmax: Optional[int] = None,
        lfr_m: int = 1,
        lfr_n: int = 1,
        htk: bool = False,
        mvn_data=None,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
    ):
        """Build the frontend.

        Args:
            fs: Sampling rate; a human-friendly string (e.g. "16k") is parsed.
            n_fft / win_length / hop_length / window / center / normalized /
                onesided: STFT configuration (used when apply_stft is True).
            n_mels: Number of mel filterbank bins (also the output size).
            fmin / fmax: Mel filterbank frequency bounds, or None for defaults.
            lfr_m / lfr_n: LFR stacking size and stride; (1, 1) disables LFR.
            htk: Use HTK-style mel scale in LogMel.
            mvn_data: Precomputed 2-row CMVN statistics array, or None.
            frontend_conf: Config dict for the optional enhancement Frontend;
                None disables it.
            apply_stft: Whether to construct the Stft module.
        """
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.fs = fs
        self.mvn_data = mvn_data
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = 'default'

    def output_size(self) -> int:
        """Feature dimension produced per frame (number of mel bins)."""
        return self.n_mels

    def forward(
            self, input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute fbank features (+ optional LFR / CMVN) for a waveform.

        Args:
            input: Waveform tensor scaled to [-1, 1]; passed to
                torchaudio's kaldi.fbank after rescaling to int16 range.
                NOTE(review): input_lengths is unused and the output is
                always batch size 1 — this path appears to assume a single
                utterance per call; confirm against callers.
            input_lengths: Unused here (kept for the AbsFrontend interface).

        Returns:
            Tuple of (features of shape (1, frames, dim), length tensor of
            shape (1,) holding the frame count as float32).
        """
        sample_frequency = self.fs
        num_mel_bins = self.n_mels
        # kaldi.fbank takes frame sizes in milliseconds.
        frame_length = self.win_length * 1000 / sample_frequency
        frame_shift = self.hop_length * 1000 / sample_frequency

        # Rescale [-1, 1] float audio to int16 range as kaldi expects.
        waveform = input * (1 << 15)

        # NOTE(review): dither=1.0 is applied unconditionally, so features
        # are stochastic even at inference time — confirm this is intended.
        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=1.0,
                          energy_floor=0.0,
                          window_type=self.window,
                          sample_frequency=sample_frequency)
        if self.lfr_m != 1 or self.lfr_n != 1:
            mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
        if self.mvn_data is not None:
            mat = apply_cmvn(mat, self.mvn_data)

        input_feats = mat[None, :]
        # Build the length tensor deterministically. The previous
        # torch.randn(1).fill_(...) consumed global RNG state for no
        # reason; torch.full keeps the same float32 dtype and value.
        feats_lens = torch.full((1,), input_feats.shape[1],
                                dtype=torch.float32)

        return input_feats, feats_lens