From d1a3a7ad90ce9deb7fb4940970bca0abb9409181 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期四, 09 二月 2023 20:44:30 +0800
Subject: [PATCH] Merge pull request #89 from alibaba-damo-academy/dev_lzr
---
funasr/bin/asr_inference_paraformer.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 3769b6c..709c5bf 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -3,6 +3,9 @@
import logging
import sys
import time
+import copy
+import os
+import codecs
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -35,6 +38,8 @@
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -78,6 +83,7 @@
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
@@ -168,6 +174,34 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from file or str
+ if hotword_list_or_file is None:
+ self.hotword_list = None
+ elif os.path.exists(hotword_list_or_file):
+ self.hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ else:
+ logging.info("Attempting to parse hotwords as str...")
+ self.hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+
+
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
@@ -229,8 +263,14 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if not isinstance(self.asr_model, ContextualParaformer):
+ if self.hotword_list:
+ logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ else:
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
@@ -388,6 +428,7 @@
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+ hotword_list_or_file = param_dict['hotword']
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
@@ -416,6 +457,7 @@
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
@@ -551,7 +593,12 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+ parser.add_argument(
+ "--hotword",
+ type=str_or_none,
+ default=None,
+ help="hotword file path or hotwords seperated by space"
+ )
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -679,8 +726,10 @@
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
+ param_dict = {'hotword': args.hotword}
kwargs = vars(args)
kwargs.pop("config", None)
+ kwargs['param_dict'] = param_dict
inference(**kwargs)
--
Gitblit v1.9.1