From 4e2fe544ae37174a3e09dfcdbbdae5abfe711e53 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 05 七月 2023 16:57:21 +0800
Subject: [PATCH] funasr sdk

---
 funasr/bin/asr_inference_launch.py |  128 +++++++++++++++++++++---------------------
 1 files changed, 65 insertions(+), 63 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index a56552d..c32357e 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,8 +19,8 @@
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 import yaml
-from typeguard import check_argument_types
 
 from funasr.bin.asr_infer import Speech2Text
 from funasr.bin.asr_infer import Speech2TextMFCCA
@@ -31,11 +31,10 @@
 from funasr.bin.punc_infer import Text2Punc
 from funasr.bin.tp_infer import Speech2Timestamp
 from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.modules.beam_search.beam_search import Hypothesis
 from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.vad import VADTask
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import asr_utils, postprocess_utils
@@ -80,7 +79,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -142,18 +140,16 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -242,7 +238,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
 
@@ -260,8 +255,10 @@
     if param_dict is not None:
         hotword_list_or_file = param_dict.get('hotword')
         export_mode = param_dict.get("export_mode", False)
+        clas_scale = param_dict.get('clas_scale', 1.0)
     else:
         hotword_list_or_file = None
+        clas_scale = 1.0
 
     if kwargs.get("device", None) == "cpu":
         ngpu = 0
@@ -294,6 +291,7 @@
         penalty=penalty,
         nbest=nbest,
         hotword_list_or_file=hotword_list_or_file,
+        clas_scale=clas_scale,
     )
 
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -329,17 +327,15 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         if param_dict is not None:
@@ -485,7 +481,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
 
@@ -580,17 +575,15 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=1,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
-            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         if param_dict is not None:
@@ -626,7 +619,28 @@
             data_with_index = [(vadsegments[i], i) for i in range(n)]
             sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
             results_sorted = []
-            batch_size_token_ms = batch_size_token * 60
+            
+            if not len(sorted_data):
+                key = keys[0]
+                # no active segments after VAD
+                if writer is not None:
+                    # Write empty results
+                    ibest_writer["token"][key] = ""
+                    ibest_writer["token_int"][key] = ""
+                    ibest_writer["vad"][key] = ""
+                    ibest_writer["text"][key] = ""
+                    ibest_writer["text_with_punc"][key] = ""
+                    if use_timestamp:
+                        ibest_writer["time_stamp"][key] = ""
+
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
+            batch_size_token_ms = batch_size_token*60
+            if speech2text.device == "cpu":
+                batch_size_token_ms = 0
+            batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+            
             batch_size_token_ms_cum = 0
             beg_idx = 0
             for j, _ in enumerate(range(0, n)):
@@ -750,7 +764,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
 
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
@@ -865,7 +878,13 @@
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            try:
+                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            except:
+                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+                if raw_inputs.ndim == 2:
+                    raw_inputs = raw_inputs[:, 0]
+                raw_inputs = torch.tensor(raw_inputs)
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)
@@ -952,7 +971,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -1027,17 +1045,15 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -1123,7 +1139,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -1182,18 +1197,16 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             fs=fs,
             mc=True,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -1313,7 +1326,6 @@
         right_context: Number of frames in right context AFTER subsampling.
         display_partial_hypotheses: Whether to display partial hypotheses.
     """
-    assert check_argument_types()
 
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
@@ -1356,10 +1368,7 @@
         left_context=left_context,
         right_context=right_context,
     )
-    speech2text = Speech2TextTransducer.from_pretrained(
-        model_tag=model_tag,
-        **speech2text_kwargs,
-    )
+    speech2text = Speech2TextTransducer(**speech2text_kwargs)
 
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -1369,20 +1378,14 @@
                  **kwargs,
                  ):
         # 3. Build data-iterator
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(
-                speech2text.asr_train_args, False
-            ),
-            collate_fn=ASRTask.build_collate_fn(
-                speech2text.asr_train_args, False
-            ),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         # 4 .Start for-loop
@@ -1469,7 +1472,6 @@
         param_dict: dict = None,
         **kwargs,
 ):
-    assert check_argument_types()
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
     if word_lm_train_config is not None:
@@ -1529,18 +1531,16 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=speech2text.asr_train_args,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
             mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -1620,6 +1620,8 @@
         return inference_mfcca(**kwargs)
     elif mode == "rnnt":
         return inference_transducer(**kwargs)
+    elif mode == "bat":
+        return inference_transducer(**kwargs)
     elif mode == "sa_asr":
         return inference_sa_asr(**kwargs)
     else:

--
Gitblit v1.9.1