王梦迪
2025-05-20 fe588bc508c0076bb007d6ed36c18ac8ecb341ac
funasr/auto/auto_frontend.py
@@ -19,7 +19,7 @@
from funasr.download.file import download_from_url
from funasr.auto.auto_model import prepare_data_iterator
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.download.download_from_hub import download_model
from funasr.download.download_model_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -52,7 +52,7 @@
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        device = kwargs.get("device", "cuda")
        if device == "cpu":
            batch_size = 1
@@ -60,7 +60,7 @@
        result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        # pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
@@ -87,15 +87,23 @@
                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )
            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            if kwargs.get("return_pt", True):
                speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
            else:
                speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
            batch = {
                "input": speech,
                "input_len": speech_lengths,
                "key": key_batch,
                "data_type": "fbank",
            }
            result_list.append(batch)
            pbar.update(1)
            description = f"{meta_data}, "
            pbar.set_description(description)
            # pbar.update(1)
            # description = f"{meta_data}, "
            # pbar.set_description(description)
        time_end = time.perf_counter()
        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        # pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        return result_list