R1ckShi
2024-05-30 d097d0ca45472965d4411357d52adda5657691a2
funasr/utils/export_utils.py
@@ -78,12 +78,15 @@
            )
def _torchscripts(model, path, device='cpu'):
def _torchscripts(model, path, device='cuda'):
    dummy_input = model.export_dummy_inputs()
    if device == 'cuda':
        model = model.cuda()
        dummy_input = tuple([i.cuda() for i in dummy_input])
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])
    # model_script = torch.jit.script(model)
    model_script = torch.jit.trace(model, dummy_input)