zhifu gao
2024-06-04 3b0526e7be3565c42007313b90a018a2f8c8dff1
funasr/auto/auto_frontend.py
@@ -52,7 +52,7 @@
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        device = kwargs.get("device", "cuda")
        if device == "cpu":
            batch_size = 1
@@ -87,8 +87,16 @@
                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )
            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            if kwargs.get("return_pt", True):
                speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
            else:
                speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
            batch = {
                "input": speech,
                "input_len": speech_lengths,
                "key": key_batch,
                data_type: "fbank",
            }
            result_list.append(batch)
            pbar.update(1)