From f6b611de44c3a535befa96da552d07b0ed1b073c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Dec 2023 15:52:16 +0800
Subject: [PATCH] funasr1.0
---
funasr/bin/inference.py | 21 ++++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fda7abe..16ad0e2 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -18,6 +18,7 @@
from funasr.register import tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.timestamp_tools import time_stamp_sentence
def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
"""
@@ -46,7 +47,7 @@
data = lines["source"]
key = data["key"] if "key" in data else key
else: # filelist, wav.scp, text.txt: id \t data or data
- lines = line.strip().split()
+ lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines)>1 else lines[0]
key = lines[0] if len(lines)>1 else key
@@ -227,6 +228,7 @@
# step.1: compute the vad model
model = self.vad_model
kwargs = self.vad_kwargs
+ kwargs.update(cfg)
beg_vad = time.time()
res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
end_vad = time.time()
@@ -322,6 +324,23 @@
result["key"] = key
results_ret_list.append(result)
pbar_total.update(1)
+
+ # step.3 compute punc model
+ model = self.punc_model
+ kwargs = self.punc_kwargs
+ kwargs.update(cfg)
+
+ for i, result in enumerate(results_ret_list):
+ beg_punc = time.time()
+ res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
+ end_punc = time.time()
+ print(f"time punc: {end_punc - beg_punc:0.3f}")
+
+ # sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
+ # results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
+ # results_ret_list[i]["sentences"] = sentences
+ # results_ret_list[i]["text_with_punc"] = res[i]["text"]
+
pbar_total.update(1)
end_total = time.time()
time_escape_total_all_samples = end_total - beg_total
--
Gitblit v1.9.1