From e54535e5ebec1871c404dc73653885c3a0114cbc Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 06 十二月 2023 11:29:43 +0800
Subject: [PATCH] update spk inference

---
 funasr/bin/asr_inference_launch.py |    7 +++++--
 1 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 59e61ee..f34bfb2 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -952,10 +952,13 @@
             #####  speaker_verification  #####
             ##################################
             # load sv model
-            sv_model_dict = torch.load(sv_model_file)
-            sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
             if ngpu > 0:
+                sv_model_dict = torch.load(sv_model_file)
+                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
                 sv_model.cuda()
+            else:
+                sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
+                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
             sv_model.load_state_dict(sv_model_dict)
             print(f'load sv model params: {sv_model_file}')
             sv_model.eval()

--
Gitblit v1.9.1