From 0856ea2ebdcb976db6e786de5cd79fae3d35cd4c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 二月 2023 18:18:35 +0800
Subject: [PATCH] Merge pull request #136 from alibaba-damo-academy/dev_cmz
---
funasr/bin/asr_inference_paraformer_vad_punc.py | 85 ++++++++++++++++++++++++++++++++++++++----
1 files changed, 77 insertions(+), 8 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index ee36135..408b5b9 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -5,6 +5,10 @@
import logging
import sys
import time
+import os
+import codecs
+import tempfile
+import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -39,9 +43,9 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
+from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
-from funasr.models.e2e_asr_paraformer import BiCifParaformer
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -79,6 +83,7 @@
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
@@ -169,6 +174,58 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from str, local file or url
+ # for None
+ if hotword_list_or_file is None:
+ self.hotword_list = None
+ # for text str input
+ elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords as str...")
+ self.hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file):
+ logging.info("Attempting to parse hotwords from local txt...")
+ self.hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+ else:
+ logging.info("Attempting to parse hotwords from url...")
+ work_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+ local_file = requests.get(hotword_list_or_file)
+ open(text_file_path, "wb").write(local_file.content)
+ hotword_list_or_file = text_file_path
+ self.hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
@@ -233,8 +290,15 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ if not isinstance(self.asr_model, ContextualParaformer):
+ if self.hotword_list:
+ logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ else:
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
_, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
@@ -286,9 +350,7 @@
timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
else:
- time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time,
- end_time)
- results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+ results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
@@ -515,6 +577,11 @@
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
@@ -553,6 +620,7 @@
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
text2punc = None
@@ -636,7 +704,8 @@
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
- if use_timestamp and time_stamp is not None:
+
+ if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)
--
Gitblit v1.9.1