BienBoy
2025-02-05 6ebf6e48eb0368518452312c803c58a65fe9bd26
fix: resolve CPU runtime error introduced by previous commit (c1e365f) (#2375)

Fixed a bug that caused a runtime error when running the model on CPU, which was introduced in commit c1e365fea09aafda387cac12fdff43d28c598979. The error was related to incorrect handling of device placement.
4 files changed
25 ■■■■■ changed files
funasr/auto/auto_model.py 7 ●●●● patch | view | raw | blame | history
funasr/bin/train.py 6 ●●●●● patch | view | raw | blame | history
funasr/bin/train_ds.py 6 ●●●●● patch | view | raw | blame | history
funasr/models/language_model/rnn/decoders.py 6 ●●●●● patch | view | raw | blame | history
funasr/auto/auto_model.py
@@ -366,8 +366,11 @@
        if pbar:
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        with torch.cuda.device(next(model.parameters()).device):
            torch.cuda.empty_cache()
        device = next(model.parameters()).device
        if device.type == 'cuda':
            with torch.cuda.device():
                torch.cuda.empty_cache()
        return asr_result_list
    def inference_with_vad(self, input, input_len=None, **cfg):
funasr/bin/train.py
@@ -221,8 +221,10 @@
            )
            trainer.start_step = 0
            with torch.cuda.device(kwargs["device"]):
                torch.cuda.empty_cache()
            device = next(model.parameters()).device
            if device.type == 'cuda':
                with torch.cuda.device():
                    torch.cuda.empty_cache()
            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
            logging.info(
funasr/bin/train_ds.py
@@ -184,8 +184,10 @@
            )
            trainer.start_step = 0
            with torch.cuda.device(kwargs["device"]):
                torch.cuda.empty_cache()
            device = next(model.parameters()).device
            if device.type == 'cuda':
                with torch.cuda.device():
                    torch.cuda.empty_cache()
            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
            logging.info(
funasr/models/language_model/rnn/decoders.py
@@ -873,8 +873,10 @@
                        ctc_state[idx], accum_best_ids
                    )
        with torch.cuda.device(vscores.device):
            torch.cuda.empty_cache()
        device = vscores.device
        if device.type == 'cuda':
            with torch.cuda.device():
                torch.cuda.empty_cache()
        dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
        ended_hyps = [