From fae856e23d45fd27d5fd55fd036e8e3fc7b24915 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 02 六月 2023 23:00:08 +0800
Subject: [PATCH] update funasr-onnx-offline
---
funasr/bin/asr_inference_launch.py | 806 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 files changed, 753 insertions(+), 53 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1870032..f84212d 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -61,15 +64,181 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_inference import SpeechText2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
-from funasr.bin.punctuation_infer import Text2Punc
+
+
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextSAASR
+
+def inference_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}\n".format(text))
+ return asr_result_list
+
+ return _forward
def inference_paraformer(
@@ -161,7 +330,7 @@
speech2text = Speech2TextParaformer(**speech2text_kwargs)
if timestamp_model_file is not None:
- speechtext2timestamp = SpeechText2Timestamp(
+ speechtext2timestamp = Speech2Timestamp(
timestamp_cmvn_file=cmvn_file,
timestamp_model_file=timestamp_model_file,
timestamp_infer_config=timestamp_infer_config,
@@ -431,6 +600,9 @@
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
+ batch_size_token = kwargs.get("batch_size_token", 6000)
+ print("batch_size_token: ", batch_size_token)
+
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
@@ -473,8 +645,10 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+ beg_vad = time.time()
vad_results = speech2vadsegment(**batch)
+ end_vad = time.time()
+ print("time cost vad: ", end_vad-beg_vad)
_, vadsegments = vad_results[0], vad_results[1][0]
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
@@ -483,17 +657,29 @@
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
- for j, beg_idx in enumerate(range(0, n, batch_size)):
- end_idx = min(n, beg_idx + batch_size)
+ batch_size_token_ms = batch_size_token*60
+ batch_size_token_ms_cum = 0
+ beg_idx = 0
+ for j, _ in enumerate(range(0, n)):
+ batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
+ if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
+ continue
+ batch_size_token_ms_cum = 0
+ end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-
+ beg_idx = end_idx
batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
batch = to_device(batch, device=device)
+ print("batch: ", speech_j.shape[0])
+ beg_asr = time.time()
results = speech2text(**batch)
+ end_asr = time.time()
+ print("time cost asr: ", end_asr - beg_asr)
if len(results) < 1:
results = [["", [], [], [], [], [], []]]
results_sorted.extend(results)
+
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
@@ -532,7 +718,10 @@
text_postprocessed_punc = text_postprocessed
punc_id_list = []
if len(word_lists) > 0 and text2punc is not None:
+ beg_punc = time.time()
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ end_punc = time.time()
+ print("time cost punc: ", end_punc-beg_punc)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
@@ -931,12 +1120,547 @@
return _forward
+def inference_mfcca(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextMFCCA(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ fs=fs,
+ mc=True,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
+
+def inference_transducer(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
+) -> None:
+ """Transducer model inference.
+ Args:
+ output_dir: Output directory path.
+ batch_size: Batch decoding size.
+ dtype: Data type.
+ beam_size: Beam size.
+ ngpu: Number of GPUs.
+ seed: Random number generator seed.
+ lm_weight: Weight of language model.
+ nbest: Number of final hypothesis.
+ num_workers: Number of workers.
+ log_level: Level of verbose for logs.
+ data_path_and_name_and_type:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model path.
+ model_tag: Model tag.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ key_file: File key.
+ allow_variable_data_keys: Whether to allow variable data keys.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
+ assert check_argument_types()
+
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1:
+ device = "cuda"
+ else:
+ device = "cpu"
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ beam_search_config=beam_search_config,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ dtype=dtype,
+ beam_size=beam_size,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ quantize_asr_model=quantize_asr_model,
+ quantize_modules=quantize_modules,
+ quantize_dtype=quantize_dtype,
+ streaming=streaming,
+ simu_streaming=simu_streaming,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+ speech2text = Speech2TextTransducer.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(
+ speech2text.asr_train_args, False
+ ),
+ collate_fn=ASRTask.build_collate_fn(
+ speech2text.asr_train_args, False
+ ),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4 .Start for-loop
+ with DatadirWriter(output_dir) as writer:
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
+
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
+
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
+
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx : _end], is_final=False
+ )
+
+ final_hyps = speech2text.streaming_decode(
+ speech[_end : len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
+
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ ibest_writer = writer[f"{n}best_recog"]
+
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ ibest_writer["text"][key] = text
+
+
+ return _forward
+
+
+def inference_sa_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextSAASR(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["text_id"][key] = text_id
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}".format(text))
+ logging.info("text_id predictions: {}\n".format(text_id))
+ return asr_result_list
+
+ return _forward
+
+
+def inference_launch(**kwargs):
+ if 'mode' in kwargs:
+ mode = kwargs['mode']
+ else:
+ logging.info("Unknown decoding mode.")
+ return None
+ if mode == "asr":
+ return inference_asr(**kwargs)
+ elif mode == "uniasr":
+ return inference_uniasr(**kwargs)
+ elif mode == "paraformer":
+ return inference_paraformer(**kwargs)
+ elif mode == "paraformer_fake_streaming":
+ return inference_paraformer(**kwargs)
+ elif mode == "paraformer_streaming":
+ return inference_paraformer_online(**kwargs)
+ elif mode.startswith("paraformer_vad"):
+ return inference_paraformer_vad_punc(**kwargs)
+ elif mode == "mfcca":
+ return inference_mfcca(**kwargs)
+ elif mode == "rnnt":
+ return inference_transducer(**kwargs)
+ elif mode == "sa_asr":
+ return inference_sa_asr(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
+
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
@@ -946,7 +1670,7 @@
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
-
+
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
@@ -979,7 +1703,7 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -988,14 +1712,20 @@
action="append",
)
group.add_argument("--key_file", type=str_or_none)
+ parser.add_argument(
+ "--hotword",
+ type=str_or_none,
+ default=None,
+ help="hotword file path or hotwords seperated by space"
+ )
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
-
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -1058,7 +1788,7 @@
default={},
help="The keyword arguments for transducer beam search.",
)
-
+
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
@@ -1104,8 +1834,8 @@
type=bool,
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
- )
-
+ )
+
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
@@ -1129,8 +1859,8 @@
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
- )
-
+ )
+
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
@@ -1157,36 +1887,6 @@
help="CTC weight in joint decoding",
)
return parser
-
-
-
-def inference_launch(**kwargs):
- if 'mode' in kwargs:
- mode = kwargs['mode']
- else:
- logging.info("Unknown decoding mode.")
- return None
- if mode == "asr":
- from funasr.bin.asr_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "uniasr":
- return inference_uniasr(**kwargs)
- elif mode == "paraformer":
- return inference_paraformer(**kwargs)
- elif mode == "paraformer_streaming":
- return inference_paraformer_online(**kwargs)
- elif mode.startswith("paraformer_vad"):
- return inference_paraformer_vad_punc(**kwargs)
- elif mode == "mfcca":
- from funasr.bin.asr_inference_mfcca import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "rnnt":
- from funasr.bin.asr_inference_rnnt import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
def main(cmd=None):
@@ -1222,4 +1922,4 @@
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
--
Gitblit v1.9.1