funasr/utils/export_utils.py
@@ -78,12 +78,15 @@ ) def _torchscripts(model, path, device='cpu'): def _torchscripts(model, path, device='cuda'): dummy_input = model.export_dummy_inputs() if device == 'cuda': model = model.cuda() dummy_input = tuple([i.cuda() for i in dummy_input]) if isinstance(dummy_input, torch.Tensor): dummy_input = dummy_input.cuda() else: dummy_input = tuple([i.cuda() for i in dummy_input]) # model_script = torch.jit.script(model) model_script = torch.jit.trace(model, dummy_input)