funasr/utils/export_utils.py
@@ -20,10 +20,12 @@ export_dir=export_dir, **kwargs ) elif type == 'torchscript': elif type == 'torchscripts': device = 'cuda' if torch.cuda.is_available() else 'cpu' _torchscripts( m, path=export_dir, device=device ) print("output dir: {}".format(export_dir)) @@ -88,6 +90,5 @@ else: dummy_input = tuple([i.cuda() for i in dummy_input]) # model_script = torch.jit.script(model) model_script = torch.jit.trace(model, dummy_input) model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))