From 0271fbe4fd11114c259ec970b1b04938010a16e3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 18:53:33 +0800
Subject: [PATCH] inference
---
funasr/bin/asr_infer.py | 251 +++++++++++++++++++++++++++++++++++
funasr/bin/asr_inference_launch.py | 164 +++++++++++++++++++++++
2 files changed, 414 insertions(+), 1 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index bcf5718..f8736e9 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -32,6 +32,7 @@
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
@@ -58,7 +59,7 @@
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-
+from funasr.tasks.asr import frontend_choices
class Speech2Text:
"""Speech2Text class
@@ -1599,3 +1600,251 @@
return Speech2Text(**kwargs)
+
+class Speech2TextSAASR:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend == 'wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ else:
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, None, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
+ profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+ ) -> List[
+ Tuple[
+ Optional[str],
+ Optional[str],
+ List[str],
+ List[int],
+ Union[HypothesisSAASR],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, text_id, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if isinstance(profile, np.ndarray):
+ profile = torch.tensor(profile)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+ if isinstance(asr_enc, tuple):
+ asr_enc = asr_enc[0]
+ if isinstance(spk_enc, tuple):
+ spk_enc = spk_enc[0]
+ assert len(asr_enc) == 1, len(asr_enc)
+ assert len(spk_enc) == 1, len(spk_enc)
+
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1: last_pos]
+ else:
+ token_int = hyp.yseq[1: last_pos].tolist()
+
+ spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
+
+ token_ori = self.converter.ids2tokens(token_int)
+ text_ori = self.tokenizer.tokens2text(token_ori)
+
+ text_ori_spklist = text_ori.split('$')
+ cur_index = 0
+ spk_choose = []
+ for i in range(len(text_ori_spklist)):
+ text_ori_split = text_ori_spklist[i]
+ n = len(text_ori_split)
+ spk_weights_local = spk_weigths[cur_index: cur_index + n]
+ cur_index = cur_index + n + 1
+ spk_weights_local = spk_weights_local.mean(dim=0)
+ spk_choose_local = spk_weights_local.argmax(-1)
+ spk_choose.append(spk_choose_local.item() + 1)
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ text_spklist = text.split('$')
+ assert len(spk_choose) == len(text_spklist)
+
+ spk_list = []
+ for i in range(len(text_spklist)):
+ text_split = text_spklist[i]
+ n = len(text_split)
+ spk_list.append(str(spk_choose[i]) * n)
+
+ text_id = '$'.join(spk_list)
+
+ assert len(text) == len(text_id)
+
+ results.append((text, text_id, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f9b8571..dbbb3ed 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -77,6 +77,7 @@
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextSAASR
def inference_asr(
maxlenratio: float,
@@ -1444,6 +1445,167 @@
return _forward
+def inference_sa_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextSAASR(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["text_id"][key] = text_id
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}".format(text))
+ logging.info("text_id predictions: {}\n".format(text_id))
+ return asr_result_list
+
+ return _forward
+
+
def inference_launch(**kwargs):
if 'mode' in kwargs:
mode = kwargs['mode']
@@ -1464,6 +1626,8 @@
return inference_mfcca(**kwargs)
elif mode == "rnnt":
return inference_transducer(**kwargs)
+ elif mode == "sa_asr":
+ return inference_sa_asr(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
--
Gitblit v1.9.1