From 7a6d64fd23629bbb5ccc26c707f6e14005708ff3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 13 二月 2023 18:57:52 +0800
Subject: [PATCH] export model

---
 /dev/null |  123 -----------------------------------------
 1 files changed, 0 insertions(+), 123 deletions(-)

diff --git a/test_local/fbank.py b/test_local/fbank.py
deleted file mode 100644
index 26daa45..0000000
--- a/test_local/fbank.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-
-from typing import Tuple
-
-import numpy as np
-import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
-import kaldi_native_fbank as knf
-
class WavFrontend(AbsFrontend):
	"""Conventional frontend structure for ASR.

	Converts a padded batch of raw waveforms into Kaldi-compatible
	log-mel filterbank features via torchaudio's kaldi compliance layer,
	then re-pads the per-utterance feature matrices to a common length.
	"""

	def __init__(
		self,
		cmvn_file: str = None,  # path to CMVN stats; stored but not applied in this block
		fs: int = 16000,  # sample rate forwarded to kaldi.fbank
		window: str = 'hamming',  # window type forwarded to kaldi.fbank
		n_mels: int = 80,  # number of mel bins
		frame_length: int = 25,  # frame length in milliseconds
		frame_shift: int = 10,  # frame shift in milliseconds
		filter_length_min: int = -1,  # stored only; unused in this block
		filter_length_max: int = -1,  # stored only; unused in this block
		lfr_m: int = 1,  # low-frame-rate stacking factor; affects output_size() only here
		lfr_n: int = 1,  # low-frame-rate subsampling factor; stored only
		dither: float = 1.0,  # dither amount forwarded to kaldi.fbank (0.0 disables)
		snip_edges: bool = True,  # stored only; NOT forwarded to kaldi.fbank in this block
		upsacle_samples: bool = True,  # NOTE(review): typo of "upscale_samples"; kept as-is for caller compatibility
	):
		# Runtime type validation of the arguments above (typeguard).
		assert check_argument_types()
		super().__init__()
		self.fs = fs
		self.window = window
		self.n_mels = n_mels
		self.frame_length = frame_length
		self.frame_shift = frame_shift
		self.filter_length_min = filter_length_min
		self.filter_length_max = filter_length_max
		self.lfr_m = lfr_m
		self.lfr_n = lfr_n
		self.cmvn_file = cmvn_file
		self.dither = dither
		self.snip_edges = snip_edges
		self.upsacle_samples = upsacle_samples

	def output_size(self) -> int:
		"""Feature dimension after (potential) LFR stacking: n_mels * lfr_m."""
		return self.n_mels * self.lfr_m

	def forward(
		self,
		input: torch.Tensor,
		input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
		"""Compute fbank features for a padded batch of waveforms.

		Args:
			input: (batch, max_samples) float waveforms, assumed in [-1, 1] —
				they are rescaled to int16 range below; TODO confirm with callers.
			input_lengths: per-utterance valid sample counts.

		Returns:
			Tuple of (feats_pad, feats_lens):
			feats_pad — (batch, max_frames, n_mels) zero-padded features;
			feats_lens — per-utterance frame counts.
		"""
		batch_size = input.size(0)
		feats = []
		feats_lens = []
		for i in range(batch_size):
			# Strip padding before feature extraction.
			waveform_length = input_lengths[i]
			waveform = input[i][:waveform_length]
			# Scale float samples to int16 range, matching Kaldi conventions.
			waveform = waveform * (1 << 15)
			waveform = waveform.unsqueeze(0)
			mat = kaldi.fbank(waveform,
			                  num_mel_bins=self.n_mels,
			                  frame_length=self.frame_length,
			                  frame_shift=self.frame_shift,
			                  dither=self.dither,
			                  energy_floor=0.0,
			                  window_type=self.window,
			                  sample_frequency=self.fs)

			feat_length = mat.size(0)
			feats.append(mat)
			feats_lens.append(feat_length)

		feats_lens = torch.as_tensor(feats_lens)
		# Re-pad variable-length feature matrices into one batch tensor.
		feats_pad = pad_sequence(feats,
		                         batch_first=True,
		                         padding_value=0.0)
		return feats_pad, feats_lens
-
# NOTE: the duplicate `import kaldi_native_fbank as knf` that used to sit here
# was removed — the module is already imported at the top of the file.

def fbank_knf(waveform):
	"""Compute 80-dim log-mel filterbank features with kaldi-native-fbank.

	Mirrors the kaldi.fbank() configuration used by WavFrontend
	(16 kHz, hamming window, 25 ms frames / 10 ms shift, 80 mel bins)
	with dithering disabled so the two paths can be compared.

	Args:
		waveform: 1-D sequence of float samples in [-1, 1] supporting
			`* int` and `.tolist()` (e.g. a numpy array); rescaled to
			int16 range internally.

	Returns:
		np.ndarray of shape (num_frames, 80).
	"""
	opts = knf.FbankOptions()
	opts.frame_opts.samp_freq = 16000
	opts.frame_opts.dither = 0.0
	opts.frame_opts.window_type = "hamming"
	opts.frame_opts.frame_shift_ms = 10.0
	opts.frame_opts.frame_length_ms = 25.0
	opts.mel_opts.num_bins = 80
	# NOTE(review): the torch path uses energy_floor=0.0; this 1 looks
	# inconsistent — confirm before relying on exact numeric parity.
	opts.energy_floor = 1
	opts.frame_opts.snip_edges = True
	opts.mel_opts.debug_mel = False

	fbank = knf.OnlineFbank(opts)
	# Scale float samples to int16 range, matching Kaldi conventions.
	waveform = waveform * (1 << 15)
	fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
	frames = fbank.num_frames_ready
	mat = np.empty([frames, opts.mel_opts.num_bins])
	for i in range(frames):
		mat[i, :] = fbank.get_frame(i)
	return mat
-
if __name__ == '__main__':
	import librosa

	# Sanity-check script: compare kaldi-native-fbank output against the
	# torchaudio-based WavFrontend on one example utterance, with dither
	# disabled on both paths so the results should agree closely.
	path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
	waveform, fs = librosa.load(path, sr=None)  # sr=None keeps the file's native rate
	fbank = fbank_knf(waveform)
	frontend = WavFrontend(dither=0.0)
	waveform_tensor = torch.from_numpy(waveform)[None, :]
	fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
	fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
	diff = fbank - fbank_torch
	diff_max = diff.max()
	# BUG FIX: `diff` is a numpy array, which has no .abs() method —
	# the original `diff.abs().sum()` raised AttributeError. Use np.abs().
	diff_sum = np.abs(diff).sum()
	# Report the comparison instead of silently discarding it (was `pass`).
	print(f"diff_max={diff_max}, diff_sum={diff_sum}")
\ No newline at end of file

--
Gitblit v1.9.1