From 79bd015ab0ded4e5aed1b1ecf32fcbc84eefde68 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 03 Feb 2023 17:42:47 +0800
Subject: [PATCH] Merge pull request #58 from alibaba-damo-academy/dev
---
funasr/bin/asr_inference_paraformer.py | 2 +
funasr/bin/asr_inference_uniasr.py | 2 +
funasr/bin/asr_inference.py | 2 +
funasr/bin/sv_inference.py | 32 ++++++++-------
funasr/bin/asr_inference_paraformer_vad_punc.py | 8 +---
funasr/bin/asr_inference_launch.py | 27 +++++++++++++
6 files changed, 51 insertions(+), 22 deletions(-)
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4c6d2a4..f7ad9ae 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -483,6 +483,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -533,6 +534,7 @@
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 67a85d2..d72fd4b 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -223,6 +223,31 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+def inference_launch_funasr(**kwargs):
+ if 'mode' in kwargs:
+ mode = kwargs['mode']
+ else:
+ logging.info("Unknown decoding mode.")
+ return None
+ if mode == "asr":
+ from funasr.bin.asr_inference import inference
+ return inference(**kwargs)
+ elif mode == "uniasr":
+ from funasr.bin.asr_inference_uniasr import inference
+ return inference(**kwargs)
+ elif mode == "paraformer":
+ from funasr.bin.asr_inference_paraformer import inference
+ return inference(**kwargs)
+ elif mode == "paraformer_vad_punc":
+ from funasr.bin.asr_inference_paraformer_vad_punc import inference
+ return inference(**kwargs)
+ elif mode == "vad":
+ from funasr.bin.vad_inference import inference
+ return inference(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
@@ -251,7 +276,7 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch(**kwargs)
+ inference_launch_funasr(**kwargs)
if __name__ == "__main__":
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index cb140eb..0929436 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -529,6 +529,7 @@
nbest: int = 1,
num_workers: int = 1,
output_dir: Optional[str] = None,
+ param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -578,6 +579,7 @@
data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 7a289aa..10c1da6 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -659,12 +659,8 @@
punc_id_list = ""
text_postprocessed_punc = ""
- item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
- 'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list, 'token': token}
- if outputs_dict:
- item = {'text_punc': text_postprocessed_punc, 'text': text_postprocessed,
- 'punc_id': punc_id_list, 'token': token, 'time_stamp': time_stamp_postprocessed}
- item = {'key': key, 'value': item}
+ item = {'key': key, 'value': text_postprocessed_punc, 'text_postprocessed': text_postprocessed,
+ 'time_stamp': time_stamp_postprocessed, 'token': token}
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 4ecb1cc..a4bdcc1 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -521,6 +521,7 @@
token_num_relax: int = 1,
decoding_ind: int = 0,
decoding_mode: str = "model1",
+ param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -574,6 +575,7 @@
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index b0fae38..6da696a 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -171,6 +171,7 @@
streaming: bool = False,
embedding_node: str = "resnet1_dense",
sv_threshold: float = 0.9465,
+ param_dict: Optional[dict] = None,
**kwargs,
):
assert check_argument_types()
@@ -183,6 +184,7 @@
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+ logging.info("param_dict: {}".format(param_dict))
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
@@ -212,7 +214,9 @@
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
):
+ logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
@@ -233,11 +237,10 @@
# 7 .Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- embd_fd, ref_emb_fd, score_fd = None, None, None
+ embd_writer, ref_embd_writer, score_writer = None, None, None
if output_path is not None:
os.makedirs(output_path, exist_ok=True)
- embd_writer = WriteHelper("ark:{}/xvector.ark".format(output_path))
- # embd_fd = open(os.path.join(output_path, "xvector.ark"), "wb")
+ embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
sv_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
@@ -249,6 +252,7 @@
embedding, ref_embedding, score = speech2xvector(**batch)
# Only supporting batch_size==1
key = keys[0]
+ normalized_score = 0.0
if score is not None:
score = score.item()
normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
@@ -257,23 +261,21 @@
item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
sv_result_list.append(item)
if output_path is not None:
- # kaldiio.save_mat(embd_fd, embedding[0].cpu().numpy(), key)
embd_writer(key, embedding[0].cpu().numpy())
if ref_embedding is not None:
- if ref_emb_fd is None:
- # ref_emb_fd = open(os.path.join(output_path, "ref_xvector.ark"), "wb")
- ref_embd_writer = WriteHelper("ark:{}/ref_xvector.ark".format(output_path))
- score_fd = open(os.path.join(output_path, "score.txt"), "w")
- # kaldiio.save_mat(ref_emb_fd, ref_embedding[0].cpu().numpy(), key)
+ if ref_embd_writer is None:
+ ref_embd_writer = WriteHelper(
+ "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
+ )
+ score_writer = open(os.path.join(output_path, "score.txt"), "w")
ref_embd_writer(key, ref_embedding[0].cpu().numpy())
- score_fd.write("{:.6f}\n".format(score.item()))
+ score_writer.write("{} {:.6f}\n".format(key, normalized_score))
+
if output_path is not None:
- # embd_fd.close()
embd_writer.close()
- if ref_emb_fd is not None:
- # ref_emb_fd.close()
- ref_emb_fd.close()
- score_fd.close()
+ if ref_embd_writer is not None:
+ ref_embd_writer.close()
+ score_writer.close()
return sv_result_list
--
Gitblit v1.9.1