From 813027835e90d97c1d54ffddcf100a587b77af5e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 18 Apr 2023 19:36:08 +0800
Subject: [PATCH] ncpu

---
 funasr/bin/sond_inference.py                     |    2 +
 funasr/bin/sv_inference_launch.py                |    2 -
 funasr/bin/vad_inference_launch.py               |    5 --
 funasr/bin/asr_inference_mfcca.py                |    2 +
 funasr/bin/asr_inference_paraformer_vad.py       |    2 +
 funasr/bin/diar_inference_launch.py              |    2 -
 funasr/bin/lm_inference_launch.py                |    5 --
 funasr/bin/punc_inference_launch.py              |    4 --
 funasr/bin/tp_inference.py                       |    3 +
 funasr/bin/asr_inference_paraformer.py           |    4 +
 funasr/bin/asr_inference_uniasr.py               |    2 +
 funasr/bin/asr_inference.py                      |    2 +
 funasr/bin/vad_inference.py                      |    3 +
 funasr/bin/vad_inference_online.py               |    3 +
 funasr/bin/lm_inference.py                       |    7 +--
 funasr/bin/asr_inference_paraformer_streaming.py |    2 +
 funasr/bin/sv_inference.py                       |    3 +
 funasr/bin/eend_ola_inference.py                 |    2 +
 funasr/bin/punctuation_infer_vadrealtime.py      |    6 +--
 funasr/bin/tp_inference_launch.py                |    4 --
 funasr/bin/asr_inference_paraformer_vad_punc.py  |    2 +
 funasr/bin/asr_inference_launch.py               |    5 --
 22 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index f3b4d56..4722602 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -346,6 +346,8 @@
     **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 2b6716e..e10ebf4 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,9 +1,4 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py
index 6f3dbb1..e832869 100644
--- a/funasr/bin/asr_inference_mfcca.py
+++ b/funasr/bin/asr_inference_mfcca.py
@@ -472,6 +472,8 @@
     **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 8cbd419..a8ac99d 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -612,7 +612,9 @@
         **kwargs,
 ):
     assert check_argument_types()
-
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 944685f..821f694 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -536,6 +536,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 1548f9f..977dc9b 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -157,6 +157,8 @@
     **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 9dc0b79..197930f 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -484,6 +484,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 4aea720..35ecdc2 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -379,6 +379,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 83436e8..07974c0 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -2,8 +2,6 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index 01d3f29..87816dd 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -158,6 +158,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
index 15c56ca..76de6df 100644
--- a/funasr/bin/lm_inference.py
+++ b/funasr/bin/lm_inference.py
@@ -89,10 +89,9 @@
     **kwargs,
 ):
     assert check_argument_types()
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index d229cc6..dc6414f 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,9 +1,6 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-import torch
-torch.set_num_threads(1)
+
 
 import argparse
 import logging
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 2c5a286..b1d9235 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,9 +1,5 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index 5157eeb..b2db1bf 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -203,10 +203,8 @@
     **kwargs,
 ):
     assert check_argument_types()
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
index 5a0a8e2..c55bc35 100755
--- a/funasr/bin/sond_inference.py
+++ b/funasr/bin/sond_inference.py
@@ -252,6 +252,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index 7e63bbd..76b1dfb 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -179,6 +179,9 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 64a3cff..8806070 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -2,8 +2,6 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
index 6360b17..191bbf3 100644
--- a/funasr/bin/tp_inference.py
+++ b/funasr/bin/tp_inference.py
@@ -179,6 +179,9 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index 55debac..6cdff05 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,9 +1,5 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 08d65a4..aff0a44 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -192,6 +192,9 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 8fea8db..4a1f334 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,9 +1,4 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-import torch
-torch.set_num_threads(1)
 
 import argparse
 import logging
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
index 9ed0721..4d02620 100644
--- a/funasr/bin/vad_inference_online.py
+++ b/funasr/bin/vad_inference_online.py
@@ -151,6 +151,9 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:

--
Gitblit v1.9.1