zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/auto/auto_frontend.py
@@ -45,12 +45,10 @@
            del kwargs["frontend"]
        self.kwargs = kwargs
    def __call__(self, input, input_len=None, kwargs=None, **cfg):
        
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
@@ -72,27 +70,32 @@
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
            audio_sample_list = load_audio_text_image_video(
                data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=self.frontend, **kwargs)
            speech, speech_lengths = extract_fbank(
                audio_sample_list,
                data_type=kwargs.get("data_type", "sound"),
                frontend=self.frontend,
                **kwargs,
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )
            
            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            result_list.append(batch)
            
            pbar.update(1)
            description = (
                f"{meta_data}, "
            )
            description = f"{meta_data}, "
            pbar.set_description(description)
        
        time_end = time.perf_counter()
        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        
        return result_list