From 23ec48ebeb8d90fa4ee836f8dbbdfc3cf32e9f3b Mon Sep 17 00:00:00 2001
From: huangmingming <dyyzhmm@163.com>
Date: Tue, 31 Jan 2023 14:03:53 +0800
Subject: [PATCH] update server env

---
 funasr/bin/asr_inference_paraformer.py |   13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 1a73457..01237f7 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -95,10 +95,13 @@
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -166,7 +169,7 @@
         self.converter = converter
         self.tokenizer = tokenizer
         is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if ctc_weight == 0.0 and not is_use_lm:
+        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
         self.beam_search = beam_search
         logging.info(f"Beam_search: {self.beam_search}")
@@ -259,7 +262,7 @@
                     token_int = hyp.yseq[1:last_pos].tolist()
 
                 # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0, token_int))
+                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
 
                 # Change integer-ids to tokens
                 token = self.converter.ids2tokens(token_int)
@@ -650,7 +653,7 @@
                         finish_count += 1
                         # asr_utils.print_progress(finish_count / file_count)
                         if writer is not None:
-                            ibest_writer["text"][key] = text
+                            ibest_writer["text"][key] = text_postprocessed
 
                     logging.info("decoding, utt: {}, predictions: {}".format(key, text))
         rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))

--
Gitblit v1.9.1