From 331d57253ae25dd42c8e14930dee30cd8d2affa6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 四月 2023 11:41:18 +0800
Subject: [PATCH] Merge pull request #408 from alibaba-damo-academy/hnluo-patch-1

---
 funasr/bin/asr_inference_paraformer_streaming.py |   32 +++++++++++++++++++++++++++++---
 1 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index df590df..821f694 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -19,6 +19,7 @@
 
 import numpy as np
 import torch
+import torchaudio
 from typeguard import check_argument_types
 
 from funasr.fileio.datadir_writer import DatadirWriter
@@ -535,6 +536,8 @@
         **kwargs,
 ):
     assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if word_lm_train_config is not None:
         raise NotImplementedError("Word LM is not implemented")
@@ -581,7 +584,22 @@
         speech2text = Speech2TextExport(**speech2text_kwargs)
     else:
         speech2text = Speech2Text(**speech2text_kwargs)
+        
+    def _load_bytes(input):
+        middle_data = np.frombuffer(input, dtype=np.int16)
+        middle_data = np.asarray(middle_data)
+        if middle_data.dtype.kind not in 'iu':
+            raise TypeError("'middle_data' must be an array of integers")
+        dtype = np.dtype('float32')
+        if dtype.kind != 'f':
+            raise TypeError("'dtype' must be a floating point type")
 
+        i = np.iinfo(middle_data.dtype)
+        abs_max = 2 ** (i.bits - 1)
+        offset = i.min + abs_max
+        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+        return array
+    
     def _forward(
             data_path_and_name_and_type,
             raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -592,14 +610,22 @@
     ):
 
         # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, np.ndarray):
-                raw_inputs = torch.tensor(raw_inputs)
         is_final = False
+        cache = {}
         if param_dict is not None and "cache" in param_dict:
             cache = param_dict["cache"]
         if param_dict is not None and "is_final" in param_dict:
             is_final = param_dict["is_final"]
+
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+            raw_inputs = torch.tensor(raw_inputs)
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            is_final = True
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, np.ndarray):
+                raw_inputs = torch.tensor(raw_inputs)
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         asr_result_list = []

--
Gitblit v1.9.1