游雁
2024-06-09 b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c
funasr/auto/auto_frontend.py
@@ -52,7 +52,7 @@
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        device = kwargs.get("device", "cuda")
        if device == "cpu":
            batch_size = 1
@@ -60,7 +60,7 @@
        result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        # pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
@@ -87,15 +87,23 @@
                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )
            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            if kwargs.get("return_pt", True):
                speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
            else:
                speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
            batch = {
                "input": speech,
                "input_len": speech_lengths,
                "key": key_batch,
                "data_type": "fbank",
            }
            result_list.append(batch)
            pbar.update(1)
            description = f"{meta_data}, "
            pbar.set_description(description)
            # pbar.update(1)
            # description = f"{meta_data}, "
            # pbar.set_description(description)
        time_end = time.perf_counter()
        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        # pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        return result_list