From ddbc8b5eded1fff6084001d160d46b532020ecb7 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Mon, 15 Jan 2024 20:36:20 +0800
Subject: [PATCH] Merge pull request #1247 from alibaba-damo-academy/funasr1.0
---
funasr/bin/inference.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 3aab31a..7368d16 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -175,7 +175,7 @@
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
kwargs["tokenizer"] = tokenizer
kwargs["token_list"] = tokenizer.token_list
@@ -186,13 +186,13 @@
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend.lower())
+ frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# build model
- model_class = tables.model_classes.get(kwargs["model"].lower())
+ model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
model.eval()
model.to(device)
@@ -245,7 +245,7 @@
time1 = time.perf_counter()
with torch.no_grad():
- results, meta_data = model.generate(**batch, **kwargs)
+ results, meta_data = model.inference(**batch, **kwargs)
time2 = time.perf_counter()
asr_result_list.extend(results)
@@ -274,12 +274,9 @@
def generate_with_vad(self, input, input_len=None, **cfg):
# step.1: compute the vad model
- model = self.vad_model
- kwargs = self.vad_kwargs
- kwargs.update(cfg)
+ self.vad_kwargs.update(cfg)
beg_vad = time.time()
- res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
- vad_res = res
+ res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@@ -312,10 +309,7 @@
if not len(sorted_data):
logging.info("decoding, utt: {}, empty speech".format(key))
continue
-
- # if kwargs["device"] == "cpu":
- # batch_size = 0
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
@@ -443,7 +437,7 @@
# build frontend
frontend = kwargs.get("frontend", None)
if frontend is not None:
- frontend_class = tables.frontend_classes.get(frontend.lower())
+ frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
self.frontend = frontend
--
Gitblit v1.9.1