From 91027ddab49e5791fc42569b4db9dafca55735e6 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: Thu, 16 Feb 2023 22:11:18 +0800
Subject: [PATCH] fix vad results bug
---
funasr/bin/asr_inference_paraformer_timestamp.py | 23 ++++++++---------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index 3fb8764..d0baa74 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -98,10 +98,13 @@
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ if asr_model.ctc != None:
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ scorers.update(
+ ctc=ctc
+ )
token_list = asr_model.token_list
scorers.update(
- ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
@@ -169,7 +172,7 @@
self.converter = converter
self.tokenizer = tokenizer
is_use_lm = lm_weight != 0.0 and lm_file is not None
- if ctc_weight == 0.0 and not is_use_lm:
+ if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
@@ -407,7 +410,7 @@
results = speech2text(**batch)
if len(results) < 1:
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
+ results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
time_end = time.time()
forward_time = time_end - time_beg
lfr_factor = results[0][-1]
@@ -433,7 +436,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
@@ -450,16 +453,6 @@
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
return asr_result_list
-
-
-def set_parameters(language: str = None,
- sample_rate: Union[int, Dict[Any, int]] = None):
- if language is not None:
- global global_asr_language
- global_asr_language = language
- if sample_rate is not None:
- global global_sample_rate
- global_sample_rate = sample_rate
def get_parser():
--
Gitblit v1.9.1