From 9b5388fae9c8ece84d5440279c90841514e9d52b Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期三, 18 一月 2023 15:25:43 +0800
Subject: [PATCH] fix ctc beamsearch
---
funasr/bin/asr_inference_paraformer.py | 2 +-
funasr/bin/asr_inference_paraformer_timestamp.py | 2 +-
funasr/bin/asr_inference_paraformer_vad_punc.py | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index a50e038..01237f7 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -169,7 +169,7 @@
self.converter = converter
self.tokenizer = tokenizer
is_use_lm = lm_weight != 0.0 and lm_file is not None
- if ctc_weight == 0.0 and not is_use_lm:
+ if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index b646987..9560cb0 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -172,7 +172,7 @@
self.converter = converter
self.tokenizer = tokenizer
is_use_lm = lm_weight != 0.0 and lm_file is not None
- if ctc_weight == 0.0 and not is_use_lm:
+ if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 85838aa..cb33b16 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -174,7 +174,7 @@
self.converter = converter
self.tokenizer = tokenizer
is_use_lm = lm_weight != 0.0 and lm_file is not None
- if ctc_weight == 0.0 and not is_use_lm:
+ if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
self.beam_search = beam_search
logging.info(f"Beam_search: {self.beam_search}")
@@ -666,7 +666,7 @@
time_stamp_postprocessed))
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
- format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
+ format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
return asr_result_list
return _forward
--
Gitblit v1.9.1