From 2778f60346b85df2b32fdfa04133fa78c4841dc4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 21 Feb 2023 17:45:59 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
---
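Notes: the hotword-list construction previously inlined in Speech2Text.__init__
is factored out into a generate_hotwords_list() method, and the inference entry
points gain **kwargs so hotwords can be supplied per call, either as
param_dict['hotword'] or as a hotword= keyword argument. A minimal calling
sketch follows; the builder name inference_modelscope, the config/model paths,
and the hotword values are illustrative assumptions, not guaranteed by this
patch — only the hotword plumbing itself is what the diff below adds:

    import torch
    from funasr.bin.asr_inference_paraformer import inference_modelscope

    # Assumed builder returning the patched callable; paths are placeholders.
    forward = inference_modelscope(
        asr_train_config="exp/config.yaml",
        asr_model_file="exp/model.pb",
    )

    waveform = torch.randn(16000)  # 1 s of dummy 16 kHz audio

    # 'hotword' may be a space-separated string (shown), a local .txt path,
    # or an http(s) URL to a .txt file; all three are parsed the same way.
    results = forward(
        data_path_and_name_and_type=None,
        raw_inputs=waveform,
        param_dict={"hotword": "阿里巴巴 达摩院"},
    )
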
 funasr/bin/asr_inference.py                     |   1 +
 funasr/bin/asr_inference_mfcca.py               |   1 +
 funasr/bin/asr_inference_paraformer.py          | 118 +++++++++++++----
 funasr/bin/asr_inference_paraformer_vad.py      |  12 ++
 funasr/bin/asr_inference_paraformer_vad_punc.py | 116 +++++++++++++---
 funasr/bin/asr_inference_uniasr.py              |   1 +
 funasr/bin/asr_inference_uniasr_vad.py          |   1 +
 7 files changed, 151 insertions(+), 99 deletions(-)
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index ca8f2bc..318d3d7 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -395,6 +395,7 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py
index ba03449..e25b2a9 100644
--- a/funasr/bin/asr_inference_mfcca.py
+++ b/funasr/bin/asr_inference_mfcca.py
@@ -523,6 +523,7 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 18d788e..487f750 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -169,56 +169,8 @@
self.tokenizer = tokenizer
# 6. [Optional] Build hotword list from str, local file or url
- # for None
- if hotword_list_or_file is None:
- self.hotword_list = None
- # for text str input
- elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords as str...")
- self.hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file):
- logging.info("Attempting to parse hotwords from local txt...")
- self.hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- else:
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- self.hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
-
+ self.hotword_list = None
+ self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -336,6 +288,60 @@
# assert check_return_type(results)
return results
+
+ def generate_hotwords_list(self, hotword_list_or_file):
+ # for None
+ if hotword_list_or_file is None:
+ hotword_list = None
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords from local txt...")
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+ elif hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords from url...")
+ work_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+ local_file = requests.get(hotword_list_or_file)
+ open(text_file_path, "wb").write(local_file.content)
+ hotword_list_or_file = text_file_path
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for text str input
+ elif not hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords as str...")
+ hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ else:
+ hotword_list = None
+ return hotword_list
class Speech2TextExport:
"""Speech2TextExport class
@@ -648,7 +654,19 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 78fc5f3..a0dc0aa 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -228,7 +228,19 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 408b5b9..a6e1394 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -176,55 +176,8 @@
self.tokenizer = tokenizer
# 6. [Optional] Build hotword list from str, local file or url
- # for None
- if hotword_list_or_file is None:
- self.hotword_list = None
- # for text str input
- elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords as str...")
- self.hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file):
- logging.info("Attempting to parse hotwords from local txt...")
- self.hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- else:
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- self.hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
+ self.hotword_list = None
+ self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -355,6 +308,59 @@
# assert check_return_type(results)
return results
+ def generate_hotwords_list(self, hotword_list_or_file):
+ # for None
+ if hotword_list_or_file is None:
+ hotword_list = None
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords from local txt...")
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+ elif hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords from url...")
+ work_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+ local_file = requests.get(hotword_list_or_file)
+ open(text_file_path, "wb").write(local_file.content)
+ hotword_list_or_file = text_file_path
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for text str input
+ elif not hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords as str...")
+ hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ else:
+ hotword_list = None
+ return hotword_list
class Speech2VadSegment:
"""Speech2VadSegment class
@@ -637,7 +643,19 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index db09d31..c50bf17 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -433,6 +433,7 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index de32dcf..ac3b4b6 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -433,6 +433,7 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
--
Gitblit v1.9.1