From 91027ddab49e5791fc42569b4db9dafca55735e6 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期四, 16 二月 2023 22:11:18 +0800
Subject: [PATCH] fix vad results bug

---
 funasr/bin/asr_inference_paraformer_timestamp.py |   23 ++++++++---------------
 1 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index 3fb8764..d0baa74 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -98,10 +98,13 @@
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -169,7 +172,7 @@
         self.converter = converter
         self.tokenizer = tokenizer
         is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if ctc_weight == 0.0 and not is_use_lm:
+        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
         self.beam_search = beam_search
         logging.info(f"Beam_search: {self.beam_search}")
@@ -407,7 +410,7 @@
         results = speech2text(**batch)
         if len(results) < 1:
             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-            results = [[" ", ["<space>"], [2], hyp, 10, 6]] * nbest
+            results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
         time_end = time.time()
         forward_time = time_end - time_beg
         lfr_factor = results[0][-1]
@@ -433,7 +436,7 @@
                     ibest_writer["score"][key] = str(hyp.score)
     
                 if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     item = {'key': key, 'value': text_postprocessed}
                     asr_result_list.append(item)
                     finish_count += 1
@@ -450,16 +453,6 @@
     logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
                  format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
     return asr_result_list
-
-
-def set_parameters(language: str = None,
-                   sample_rate: Union[int, Dict[Any, int]] = None):
-    if language is not None:
-        global global_asr_language
-        global_asr_language = language
-    if sample_rate is not None:
-        global global_sample_rate
-        global_sample_rate = sample_rate
 
 
 def get_parser():

--
Gitblit v1.9.1