# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Part of the implementation is borrowed from espnet/espnet.
|
|
from typing import Tuple
|
|
import numpy as np
|
import torch
|
import torchaudio.compliance.kaldi as kaldi
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
from typeguard import check_argument_types
|
from torch.nn.utils.rnn import pad_sequence
|
import kaldi_native_fbank as knf
|
|
class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Extracts Kaldi-compatible fbank features per utterance with
    torchaudio and pads them into a batch tensor.

    Args:
        cmvn_file: Path to CMVN statistics (stored only; not applied here).
        fs: Sampling frequency of the input waveforms in Hz.
        window: Window type for fbank extraction (e.g. 'hamming').
        n_mels: Number of mel filterbank bins.
        frame_length: Frame length in milliseconds.
        frame_shift: Frame shift in milliseconds.
        filter_length_min: Stored only; not used in this class.
        filter_length_max: Stored only; not used in this class.
        lfr_m: Low-frame-rate stacking factor (only affects output_size()).
        lfr_n: Low-frame-rate skipping factor (stored only).
        dither: Dithering constant passed to fbank extraction.
        snip_edges: Kaldi snip_edges framing flag.
        upsacle_samples: If True, scale float waveforms in [-1, 1] up to
            int16 range before extraction. NOTE: the parameter name keeps
            the original misspelling for backward compatibility.
    """

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples

    def output_size(self) -> int:
        """Feature dimension after (optional) LFR frame stacking."""
        return self.n_mels * self.lfr_m

    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract fbank features for a padded batch of waveforms.

        Args:
            input: (batch, samples) float waveform tensor.
            input_lengths: per-utterance valid sample counts.

        Returns:
            Tuple of (padded features (batch, frames, n_mels),
            per-utterance frame counts).
        """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                # Kaldi-style extraction expects int16-range samples;
                # inputs are assumed to be floats in [-1, 1].
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              # Fix: previously stored but never forwarded,
                              # so the configured flag was silently ignored.
                              snip_edges=self.snip_edges,
                              sample_frequency=self.fs)

            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
|
|
def fbank_knf(waveform,
              sampling_rate: int = 16000,
              n_mels: int = 80,
              frame_length: float = 25.0,
              frame_shift: float = 10.0,
              dither: float = 0.0,
              window: str = "hamming",
              snip_edges: bool = True,
              energy_floor: float = 1.0):
    """Compute fbank features with kaldi-native-fbank.

    Args:
        waveform: 1-D float samples in [-1, 1] (numpy array or torch tensor).
        sampling_rate: Sample rate of `waveform` in Hz.
        n_mels: Number of mel filterbank bins.
        frame_length: Frame length in milliseconds.
        frame_shift: Frame shift in milliseconds.
        dither: Dithering constant (0 disables dithering).
        window: Window type (e.g. "hamming").
        snip_edges: Kaldi snip_edges framing flag.
        energy_floor: Energy floor used by the extractor.

    Returns:
        numpy array of shape (frames, n_mels).
    """
    # Defaults reproduce the original hard-coded configuration exactly.
    opts = knf.FbankOptions()
    opts.frame_opts.samp_freq = sampling_rate
    opts.frame_opts.dither = dither
    opts.frame_opts.window_type = window
    opts.frame_opts.frame_shift_ms = frame_shift
    opts.frame_opts.frame_length_ms = frame_length
    opts.mel_opts.num_bins = n_mels
    opts.energy_floor = energy_floor
    opts.frame_opts.snip_edges = snip_edges
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    # Scale [-1, 1] floats up to int16 range to match Kaldi conventions.
    waveform = waveform * (1 << 15)
    fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
    frames = fbank.num_frames_ready
    mat = np.empty([frames, opts.mel_opts.num_bins])
    for i in range(frames):
        mat[i, :] = fbank.get_frame(i)
    return mat
|
|
if __name__ == '__main__':
    import librosa

    # Sanity check: compare kaldi-native-fbank features against the
    # torchaudio-based WavFrontend on one sample utterance.
    path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    waveform, fs = librosa.load(path, sr=None)
    fbank = fbank_knf(waveform)
    frontend = WavFrontend(dither=0.0)
    waveform_tensor = torch.from_numpy(waveform)[None, :]
    fbank_torch, _ = frontend.forward(waveform_tensor,
                                      [waveform_tensor.size(1)])
    fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
    diff = fbank - fbank_torch
    diff_max = diff.max()
    # Fix: `diff` is a numpy ndarray, which has no .abs() method
    # (the original `diff.abs()` raised AttributeError); use np.abs.
    diff_sum = np.abs(diff).sum()
    # Report the comparison instead of silently discarding it.
    print(f"max diff: {diff_max}, sum |diff|: {diff_sum}")
|