维石
2024-06-03 d2e9bf01426af61dce7fdc4449906755524eb253
funasr/utils/export_utils.py
@@ -20,10 +20,12 @@
                export_dir=export_dir,
                **kwargs
            )
        elif type == 'torchscript':
        elif type == 'torchscripts':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            _torchscripts(
                m,
                path=export_dir,
                device=device
            )
        print("output dir: {}".format(export_dir))
@@ -78,13 +80,15 @@
            )
def _torchscripts(model, path, device='cpu'):
def _torchscripts(model, path, device='cuda'):
    dummy_input = model.export_dummy_inputs()
    if device == 'cuda':
        model = model.cuda()
        dummy_input = tuple([i.cuda() for i in dummy_input])
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])
    # model_script = torch.jit.script(model)
    model_script = torch.jit.trace(model, dummy_input)
    model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))