Daniel
2023-03-07 3547adb4fb8b8284afefb8413382592fcdfa0302
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
@@ -105,8 +105,8 @@
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: float = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            lfr_m: int = 7,
            lfr_n: int = 6,
            dither: float = 1.0
    ) -> None:
        # check_argument_types()
@@ -229,22 +229,24 @@
            if key == "config_path":
                with open(str(value), 'rb') as f:
                    config = yaml.load(f, Loader=yaml.Loader)
            if key == "cmvn_path":
                cmvn_path = str(value)
        opts = kaldifeat.FbankOptions()
        opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
        opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
        opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
        opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
        opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
        opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
        opts.frame_opts.window_type = config['frontend_conf']['window']
        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
        opts.device = torch.device(self.device)
        self.opts = opts
        self.feature_extractor = Fbank(self.opts)
        self.feature_size = opts.mel_opts.num_bins
        self.frontend = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **config['WavFrontend']['frontend_conf'])
            cmvn_file=cmvn_path,
            **config['frontend_conf'])
    def extract_feat(self,
                     waveform_list: List[np.ndarray]