From 6013d3c4a9afa81f8f7bdca4e3d9fe84639eb39b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 28 四月 2023 16:38:03 +0800
Subject: [PATCH] Merge pull request #441 from alibaba-damo-academy/dev_lhn

---
 funasr/bin/asr_inference_paraformer_streaming.py |   35 ++++++++++++++++++++++++++++++-----
 1 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index ff8bb8c..bf5590c 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import torch
+import torchaudio
 from typeguard import check_argument_types
 
 from funasr.fileio.datadir_writer import DatadirWriter
@@ -204,9 +205,12 @@
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+            if cache_en["start_idx"] == 0:
+                return []
             cache_en["tail_chunk"] = True
             feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
+            self.asr_model.frontend = None
             results = self.infer(feats, feats_len, cache)
             return results
         else:
@@ -515,6 +519,8 @@
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)
@@ -531,13 +537,32 @@
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
-        input_lens = torch.tensor([raw_inputs.shape[1]])
         asr_result_list = []
-
         cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
-        cache["encoder"]["is_final"] = is_final
-        asr_result = speech2text(cache, raw_inputs, input_lens)
-        item = {'key': "utt", 'value': asr_result}
+        item = {}
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            sample_offset = 0
+            speech_length = raw_inputs.shape[1]
+            stride_size =  chunk_size[1] * 960
+            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+            final_result = ""
+            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+                if sample_offset + stride_size >= speech_length - 1:
+                    stride_size = speech_length - sample_offset
+                    cache["encoder"]["is_final"] = True
+                else:
+                    cache["encoder"]["is_final"] = False
+                input_lens = torch.tensor([stride_size])
+                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+                if len(asr_result) != 0: 
+                    final_result += asr_result[0]
+            item = {'key': "utt", 'value': [final_result]}
+        else:
+            input_lens = torch.tensor([raw_inputs.shape[1]])
+            cache["encoder"]["is_final"] = is_final
+            asr_result = speech2text(cache, raw_inputs, input_lens)
+            item = {'key': "utt", 'value': asr_result}
+
         asr_result_list.append(item)
         if is_final:
             cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)

--
Gitblit v1.9.1