From 3360a1d9453ef0ce441cc41b0090d09b3bb296bb Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: 星期二, 04 七月 2023 20:02:45 +0800
Subject: [PATCH] Dev bat (#701)

---
 funasr/bin/asr_inference_launch.py |   34 ++++++++++++++++++++--------------
 1 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 026874e..8310791 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -21,7 +21,6 @@
 import torchaudio
 import soundfile
 import yaml
-from typeguard import check_argument_types
 
 from funasr.bin.asr_infer import Speech2Text
 from funasr.bin.asr_infer import Speech2TextMFCCA
@@ -80,7 +79,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -240,7 +238,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
 
@@ -293,7 +290,7 @@
         penalty=penalty,
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
-        clas_sacle=clas_scale,
+        clas_scale=clas_scale,
     )
 
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -483,7 +480,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
 
@@ -623,6 +619,22 @@
             sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
             results_sorted = []
             
+            if not len(sorted_data):
+                key = keys[0]
+                # no active segments after VAD
+                if writer is not None:
+                    # Write empty results
+                    ibest_writer["token"][key] = ""
+                    ibest_writer["token_int"][key] = ""
+                    ibest_writer["vad"][key] = ""
+                    ibest_writer["text"][key] = ""
+                    ibest_writer["text_with_punc"][key] = ""
+                    if use_timestamp:
+                        ibest_writer["time_stamp"][key] = ""
+
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
             batch_size_token_ms = batch_size_token*60
             if speech2text.device == "cpu":
                 batch_size_token_ms = 0
@@ -751,7 +763,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
 
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
@@ -959,7 +970,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -1128,7 +1138,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -1316,7 +1325,6 @@
         right_context: Number of frames in right context AFTER subsampling.
         display_partial_hypotheses: Whether to display partial hypotheses.
     """
-    assert check_argument_types()
 
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
@@ -1359,10 +1367,7 @@
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer.from_pretrained(
-        model_tag=model_tag,
-        **speech2text_kwargs,
-    )
+    speech2text = Speech2TextTransducer(**speech2text_kwargs)
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -1466,7 +1471,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
@@ -1615,6 +1619,8 @@
         return inference_mfcca(**kwargs)
     elif mode == "rnnt":
         return inference_transducer(**kwargs)
+    elif mode == "bat":
+        return inference_transducer(**kwargs)
     elif mode == "sa_asr":
         return inference_sa_asr(**kwargs)
     else:

--
Gitblit v1.9.1