From d105ce0d6b63bcd14edeb426fbc0acf593296be3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 May 2023 13:58:11 +0800
Subject: [PATCH] inference
---
funasr/bin/sa_asr_train.py | 3
funasr/bin/punc_infer.py | 4
funasr/bin/sv_inference_launch.py | 19
funasr/bin/lm_train.py | 3
funasr/bin/vad_inference_launch.py | 26
funasr/bin/diar_train.py | 3
funasr/bin/asr_train.py | 3
funasr/bin/diar_inference_launch.py | 48
funasr/bin/lm_inference_launch.py | 297 +++++++++
funasr/bin/punc_inference_launch.py | 24
funasr/bin/sa_asr_inference.py | 5
funasr/bin/vad_infer.py | 5
funasr/bin/diar_infer.py | 1
funasr/bin/asr_infer.py | 345 ++++++++++
funasr/bin/punc_train.py | 4
/dev/null | 406 -------------
funasr/bin/tp_infer.py | 5
funasr/bin/build_trainer.py | 5
funasr/bin/sv_infer.py | 1
funasr/bin/tp_inference_launch.py | 19
funasr/bin/asr_inference_launch.py | 604 +++++++++++++++++-
funasr/bin/asr_train_paraformer.py | 9
 22 files changed, 1308 insertions(+), 531 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 488be16..bcf5718 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
import sys
@@ -19,13 +23,15 @@
import numpy as np
import torch
+from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
@@ -47,11 +53,10 @@
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
+from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
-
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
@@ -264,7 +269,6 @@
assert check_return_type(results)
return results
-
class Speech2TextParaformer:
"""Speech2Text class
@@ -839,7 +843,6 @@
# assert check_return_type(results)
return results
-
class Speech2TextUniASR:
"""Speech2Text class
@@ -1072,9 +1075,7 @@
assert check_return_type(results)
return results
-
-
-
+
class Speech2TextMFCCA:
"""Speech2Text class
@@ -1114,6 +1115,7 @@
assert check_argument_types()
# 1. Build ASR model
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
@@ -1270,3 +1272,330 @@
return results
+class Speech2TextTransducer:
+ """Speech2Text class for Transducer models.
+ Args:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model config path.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ device: Device to use for inference.
+ beam_size: Size of beam during search.
+ dtype: Data type.
+ lm_weight: Language model weight.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ nbest: Number of final hypothesis.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
+ ) -> None:
+ """Construct a Speech2Text object."""
+ super().__init__()
+
+ assert check_argument_types()
+ from funasr.tasks.asr import ASRTransducerTask
+ asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ if quantize_asr_model:
+ if quantize_modules is not None:
+ if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+ raise ValueError(
+ "Only 'Linear' and 'LSTM' modules are currently supported"
+ " by PyTorch and in --quantize_modules"
+ )
+
+ q_config = set([getattr(torch.nn, q) for q in quantize_modules])
+ else:
+ q_config = {torch.nn.Linear}
+
+ if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+ raise ValueError(
+ "float16 dtype for dynamic quantization is not supported with torch"
+ " version < 1.5.0. Switching to qint8 dtype instead."
+ )
+ q_dtype = getattr(torch, quantize_dtype)
+
+ asr_model = torch.quantization.quantize_dynamic(
+ asr_model, q_config, dtype=q_dtype
+ ).eval()
+ else:
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ lm_scorer = lm.lm
+ else:
+ lm_scorer = None
+
+ # 4. Build BeamSearch object
+ if beam_search_config is None:
+ beam_search_config = {}
+
+ beam_search = BeamSearchTransducer(
+ asr_model.decoder,
+ asr_model.joint_network,
+ beam_size,
+ lm=lm_scorer,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ **beam_search_config,
+ )
+
+ token_list = asr_model.token_list
+
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.beam_search = beam_search
+ self.streaming = streaming
+ self.simu_streaming = simu_streaming
+ self.chunk_size = max(chunk_size, 0)
+ self.left_context = left_context
+ self.right_context = max(right_context, 0)
+
+ if not streaming or chunk_size == 0:
+ self.streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ if not simu_streaming or chunk_size == 0:
+ self.simu_streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ self.frontend = frontend
+ self.window_size = self.chunk_size + self.right_context
+
+ if self.streaming:
+ self._ctx = self.asr_model.encoder.get_encoder_input_size(
+ self.window_size
+ )
+
+ self.last_chunk_length = (
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ )
+ self.reset_inference_cache()
+
+ def reset_inference_cache(self) -> None:
+ """Reset Speech2Text parameters."""
+ self.frontend_cache = None
+
+ self.asr_model.encoder.reset_streaming_cache(
+ self.left_context, device=self.device
+ )
+ self.beam_search.reset_inference_cache()
+
+ self.num_processed_frames = torch.tensor([[0]], device=self.device)
+
+ @torch.no_grad()
+ def streaming_decode(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
+ ) -> List[HypothesisTransducer]:
+ """Speech2Text streaming call.
+ Args:
+ speech: Chunk of speech data. (S)
+ is_final: Whether speech corresponds to the final chunk of data.
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if is_final:
+ if self.streaming and speech.size(0) < self.last_chunk_length:
+ pad = torch.zeros(
+ self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
+ )
+ speech = torch.cat([speech, pad],
+ dim=0) # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.chunk_forward(
+ feats,
+ feats_lengths,
+ self.num_processed_frames,
+ chunk_size=self.chunk_size,
+ left_context=self.left_context,
+ right_context=self.right_context,
+ )
+ nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+
+ self.num_processed_frames += self.chunk_size
+
+ if is_final:
+ self.reset_inference_cache()
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
+ self.right_context)
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+
+ enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
+ """Build partial or final results from the hypotheses.
+ Args:
+ nbest_hyps: N-best hypothesis.
+ Returns:
+ results: Results containing different representation for the hypothesis.
+ """
+ results = []
+
+ for hyp in nbest_hyps:
+ token_int = list(filter(lambda x: x != 0, hyp.yseq))
+
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+
+ return results
+
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+    ) -> "Speech2TextTransducer":
+ """Build Speech2Text instance from the pretrained model.
+ Args:
+ model_tag: Model tag of the pretrained models.
+ Return:
+            : Speech2TextTransducer instance.
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+        return Speech2TextTransducer(**kwargs)
+
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1870032..4a55caa 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -61,15 +64,180 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_inference import SpeechText2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
-from funasr.bin.punctuation_infer import Text2Punc
+
+
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.asr_infer import Speech2TextTransducer
+
+def inference_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}\n".format(text))
+ return asr_result_list
+
+ return _forward
def inference_paraformer(
@@ -161,7 +329,7 @@
speech2text = Speech2TextParaformer(**speech2text_kwargs)
if timestamp_model_file is not None:
- speechtext2timestamp = SpeechText2Timestamp(
+ speechtext2timestamp = Speech2Timestamp(
timestamp_cmvn_file=cmvn_file,
timestamp_model_file=timestamp_model_file,
timestamp_infer_config=timestamp_infer_config,
@@ -931,12 +1099,382 @@
return _forward
+def inference_mfcca(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextMFCCA(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ fs=fs,
+ mc=True,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
+
+def inference_transducer(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
+) -> None:
+ """Transducer model inference.
+ Args:
+ output_dir: Output directory path.
+ batch_size: Batch decoding size.
+ dtype: Data type.
+ beam_size: Beam size.
+ ngpu: Number of GPUs.
+ seed: Random number generator seed.
+ lm_weight: Weight of language model.
+ nbest: Number of final hypothesis.
+ num_workers: Number of workers.
+ log_level: Level of verbose for logs.
+ data_path_and_name_and_type:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model path.
+ model_tag: Model tag.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ key_file: File key.
+ allow_variable_data_keys: Whether to allow variable data keys.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
+ assert check_argument_types()
+
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ beam_search_config=beam_search_config,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ dtype=dtype,
+ beam_size=beam_size,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ quantize_asr_model=quantize_asr_model,
+ quantize_modules=quantize_modules,
+ quantize_dtype=quantize_dtype,
+ streaming=streaming,
+ simu_streaming=simu_streaming,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+ speech2text = Speech2TextTransducer.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(
+ speech2text.asr_train_args, False
+ ),
+ collate_fn=ASRTask.build_collate_fn(
+ speech2text.asr_train_args, False
+ ),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4 .Start for-loop
+ with DatadirWriter(output_dir) as writer:
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
+
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
+
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
+
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx : _end], is_final=False
+ )
+
+ final_hyps = speech2text.streaming_decode(
+ speech[_end : len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
+
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ ibest_writer = writer[f"{n}best_recog"]
+
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ ibest_writer["text"][key] = text
+
+
+ return _forward
+
+
+def inference_launch(**kwargs):
+ if 'mode' in kwargs:
+ mode = kwargs['mode']
+ else:
+ logging.info("Unknown decoding mode.")
+ return None
+ if mode == "asr":
+ return inference_asr(**kwargs)
+ elif mode == "uniasr":
+ return inference_uniasr(**kwargs)
+ elif mode == "paraformer":
+ return inference_paraformer(**kwargs)
+ elif mode == "paraformer_streaming":
+ return inference_paraformer_online(**kwargs)
+ elif mode.startswith("paraformer_vad"):
+ return inference_paraformer_vad_punc(**kwargs)
+ elif mode == "mfcca":
+ return inference_mfcca(**kwargs)
+ elif mode == "rnnt":
+ return inference_transducer(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
+
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
@@ -946,7 +1484,7 @@
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
-
+
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
@@ -979,7 +1517,7 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -990,12 +1528,12 @@
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
-
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -1058,7 +1596,7 @@
default={},
help="The keyword arguments for transducer beam search.",
)
-
+
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
@@ -1104,8 +1642,8 @@
type=bool,
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
- )
-
+ )
+
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
@@ -1129,8 +1667,8 @@
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
- )
-
+ )
+
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
@@ -1157,36 +1695,6 @@
help="CTC weight in joint decoding",
)
return parser
-
-
-
-def inference_launch(**kwargs):
- if 'mode' in kwargs:
- mode = kwargs['mode']
- else:
- logging.info("Unknown decoding mode.")
- return None
- if mode == "asr":
- from funasr.bin.asr_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "uniasr":
- return inference_uniasr(**kwargs)
- elif mode == "paraformer":
- return inference_paraformer(**kwargs)
- elif mode == "paraformer_streaming":
- return inference_paraformer_online(**kwargs)
- elif mode.startswith("paraformer_vad"):
- return inference_paraformer_vad_punc(**kwargs)
- elif mode == "mfcca":
- from funasr.bin.asr_inference_mfcca import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "rnnt":
- from funasr.bin.asr_inference_rnnt import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
def main(cmd=None):
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index a43472c..0dec107 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/asr_train_paraformer.py b/funasr/bin/asr_train_paraformer.py
index 76943d5..223be14 100755
--- a/funasr/bin/asr_train_paraformer.py
+++ b/funasr/bin/asr_train_paraformer.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
@@ -9,6 +12,12 @@
def parse_args():
parser = ASRTask.get_parser()
parser.add_argument(
+ "--mode",
+ type=str,
+ default="asr",
+ help="mode",
+ )
+ parser.add_argument(
"--gpu_id",
type=int,
default=0,
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 5c30fdb..df3434f 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
import yaml
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
index f698a66..f2dcb1e 100755
--- a/funasr/bin/diar_infer.py
+++ b/funasr/bin/diar_infer.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 08004e8..69d37d6 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -362,6 +363,30 @@
return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sond":
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "sond_demo":
+ param_dict = {
+ "extract_profile": True,
+ "sv_train_config": "sv.yaml",
+ "sv_model_file": "sv.pb",
+ }
+ if "param_dict" in kwargs and kwargs["param_dict"] is not None:
+ for key in param_dict:
+ if key not in kwargs["param_dict"]:
+ kwargs["param_dict"][key] = param_dict[key]
+ else:
+ kwargs["param_dict"] = param_dict
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "eend-ola":
+ return inference_eend(mode=mode, **kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
@@ -469,29 +494,6 @@
)
return parser
-
-
-def inference_launch(mode, **kwargs):
- if mode == "sond":
- return inference_sond(mode=mode, **kwargs)
- elif mode == "sond_demo":
- param_dict = {
- "extract_profile": True,
- "sv_train_config": "sv.yaml",
- "sv_model_file": "sv.pb",
- }
- if "param_dict" in kwargs and kwargs["param_dict"] is not None:
- for key in param_dict:
- if key not in kwargs["param_dict"]:
- kwargs["param_dict"][key] = param_dict[key]
- else:
- kwargs["param_dict"] = param_dict
- return inference_sond(mode=mode, **kwargs)
- elif mode == "eend-ola":
- return inference_eend(mode=mode, **kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
index f76d1b9..16a4bd0 100755
--- a/funasr/bin/diar_train.py
+++ b/funasr/bin/diar_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
deleted file mode 100755
index 198d578..0000000
--- a/funasr/bin/lm_calc_perplexity.py
+++ /dev/null
@@ -1,211 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.lm import LMTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def calc_perplexity(
- output_dir: str,
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float],
- allow_variable_data_keys: bool,
-):
- assert check_argument_types()
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1:
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build LM
- model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
- # Wrape model to make model.nll() data-parallel
- wrapped_model = ForwardAdaptor(model, "nll")
- wrapped_model.to(dtype=getattr(torch, dtype)).eval()
- logging.info(f"Model:\n{model}")
-
- # 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 4. Start for-loop
- with DatadirWriter(output_dir) as writer:
- total_nll = 0.0
- total_ntokens = 0
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- # NOTE(kamo): data_parallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
-
- assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
- # nll: (B, L) -> (B,)
- nll = nll.detach().cpu().numpy().sum(1)
- # lengths: (B,)
- lengths = lengths.detach().cpu().numpy()
- total_nll += nll.sum()
- total_ntokens += lengths.sum()
-
- for key, _nll, ntoken in zip(keys, nll, lengths):
- if log_base is None:
- utt_ppl = np.exp(_nll / ntoken)
- else:
- utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
-
- # Write PPL of each utts for debugging or analysis
- writer["utt2nll"][key] = str(-_nll)
- writer["utt2ppl"][key] = str(utt_ppl)
- writer["utt2ntokens"][key] = str(ntoken)
-
- if log_base is None:
- ppl = np.exp(total_nll / total_ntokens)
- else:
- ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
- with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f:
- f.write(f"{ppl}\n")
- with (Path(output_dir) / "base").open("w", encoding="utf-8") as f:
- if log_base is None:
- _log_base = np.e
- else:
- _log_base = log_base
- f.write(f"{_log_base}\n")
- logging.info(f"PPL={ppl}")
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Calc perplexity",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument(
- "--log_base",
- type=float_or_none,
- default=None,
- help="The base of logarithm for Perplexity. "
- "If None, napier's constant is used.",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- calc_perplexity(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
deleted file mode 100644
index 76de6df..0000000
--- a/funasr/bin/lm_inference.py
+++ /dev/null
@@ -1,406 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.tasks.lm import LMTask
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-def inference(
- output_dir: str,
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float],
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[List[Any], bytes, str] = None,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- output_dir=output_dir,
- raw_inputs=raw_inputs,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- train_config=train_config,
- model_file=model_file,
- log_base = log_base,
- allow_variable_data_keys = allow_variable_data_keys,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build Model
- model, train_args = LMTask.build_model_from_file(
- train_config, model_file, device)
- wrapped_model = ForwardAdaptor(model, "nll")
- wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- logging.info(f"Model:\n{model}")
-
- preprocessor = LMPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file
- )
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
- ):
- results = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "lm demo"
- if line=="":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- batch = {}
- batch['text'] = line
- if preprocessor != None:
- batch = preprocessor(key, batch)
-
- # Force data-precision
- for name in batch:
- value = batch[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f"All values must be converted to np.ndarray object "
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
- # Cast to desired type
- if value.dtype.kind == "f":
- value = value.astype("float32")
- elif value.dtype.kind == "i":
- value = value.astype("long")
- else:
- raise NotImplementedError(f"Not supported dtype: {value.dtype}")
- batch[name] = value
-
- batch["text_lengths"] = torch.from_numpy(
- np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = np.expand_dims(batch["text"], axis=0)
-
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## compute ppl
- ppl_out_batch = ""
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for sent_ids, sent_nll in zip(batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key+":\n"] = ppl_out
- results.append(item)
-
- return results
-
- # 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 4. Start for-loop
- total_nll = 0.0
- total_ntokens = 0
- ppl_out_all = ""
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- ppl_out_batch = ""
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- # NOTE(kamo): data_parallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## print ppl
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- utt2nll = round(-sent_nll_sum.item(), 5)
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key+":\n"] = ppl_out
- writer["utt2nll"][key] = str(utt2nll)
- results.append(item)
-
- ppl_out_all += ppl_out_batch
-
- assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
- # nll: (B, L) -> (B,)
- nll = nll.detach().cpu().numpy().sum(1)
- # lengths: (B,)
- lengths = lengths.detach().cpu().numpy()
- total_nll += nll.sum()
- total_ntokens += lengths.sum()
-
- if log_base is None:
- ppl = np.exp(total_nll / total_ntokens)
- else:
- ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
- avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
- total_nll=round(-total_nll.item(), 4),
- total_ppl=round(ppl.item(), 4)
- )
- item = {'key': 'AVG PPL', 'value': avg_ppl}
- ppl_out_all += avg_ppl
- if writer is not None:
- writer["ppl"]["AVG PPL : "] = avg_ppl
- results.append(item)
-
- return results
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Calc perplexity",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument(
- "--log_base",
- type=float_or_none,
- default=10,
- help="The base of logarithm for Perplexity. "
- "If None, napier's constant is used.",
- required=False
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- required=False
- )
- group.add_argument(
- "--raw_inputs",
- type=str,
- required=False
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group.add_argument("--split_with_space", type=str2bool, default=False)
- group.add_argument("--seg_dict_file", type=str_or_none)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- inference(**kwargs)
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index dc6414f..0840e6e 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
@@ -14,8 +17,294 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+def inference_lm(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build Model
+ model, train_args = LMTask.build_model_from_file(
+ train_config, model_file, device)
+ wrapped_model = ForwardAdaptor(model, "nll")
+ wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ logging.info(f"Model:\n{model}")
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file
+ )
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+        if raw_inputs is not None:
+ line = raw_inputs.strip()
+ key = "lm demo"
+ if line == "":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ batch = {}
+ batch['text'] = line
+            if preprocessor is not None:
+ batch = preprocessor(key, batch)
+
+ # Force data-precision
+ for name in batch:
+ value = batch[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype("float32")
+ elif value.dtype.kind == "i":
+ value = value.astype("long")
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ batch[name] = value
+
+ batch["text_lengths"] = torch.from_numpy(
+ np.array([len(batch["text"])], dtype='int32'))
+ batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## compute ppl
+ ppl_out_batch = ""
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for sent_ids, sent_nll in zip(batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key + ":\n"] = ppl_out
+ results.append(item)
+
+ return results
+
+ # 3. Build data-iterator
+ loader = LMTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=LMTask.build_collate_fn(train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4. Start for-loop
+ total_nll = 0.0
+ total_ntokens = 0
+ ppl_out_all = ""
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ ppl_out_batch = ""
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ # NOTE(kamo): data_parallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## print ppl
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ utt2nll = round(-sent_nll_sum.item(), 5)
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key + ":\n"] = ppl_out
+ writer["utt2nll"][key] = str(utt2nll)
+ results.append(item)
+
+ ppl_out_all += ppl_out_batch
+
+ assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+ # nll: (B, L) -> (B,)
+ nll = nll.detach().cpu().numpy().sum(1)
+ # lengths: (B,)
+ lengths = lengths.detach().cpu().numpy()
+ total_nll += nll.sum()
+ total_ntokens += lengths.sum()
+
+ if log_base is None:
+ ppl = np.exp(total_nll / total_ntokens)
+ else:
+ ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+ avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+ total_nll=round(-total_nll.item(), 4),
+ total_ppl=round(ppl.item(), 4)
+ )
+ item = {'key': 'AVG PPL', 'value': avg_ppl}
+ ppl_out_all += avg_ppl
+ if writer is not None:
+ writer["ppl"]["AVG PPL : "] = avg_ppl
+ results.append(item)
+
+ return results
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "transformer":
+ return inference_lm(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
@@ -89,14 +378,6 @@
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="lm")
return parser
-
-def inference_launch(mode, **kwargs):
- if mode == "transformer":
- from funasr.bin.lm_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index 8641465..22b5f9c 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 41c4da3..4b6cd27 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
from pathlib import Path
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 594a7be..7f60f81 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,5 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -175,6 +177,16 @@
return _forward
+
+def inference_launch(mode, **kwargs):
+ if mode == "punc":
+ return inference_punc(**kwargs)
+ if mode == "punc_VadRealtime":
+ return inference_punc_vad_realtime(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Punctuation inference",
@@ -228,16 +240,6 @@
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="punc")
return parser
-
-
-def inference_launch(mode, **kwargs):
- if mode == "punc":
- return inference_punc(**kwargs)
- if mode == "punc_VadRealtime":
- return inference_punc_vad_realtime(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
index 61b63ec..aeded7b 100644
--- a/funasr/bin/punc_train.py
+++ b/funasr/bin/punc_train.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
from funasr.tasks.punctuation import PunctuationTask
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
index c894f54..7a5ba83 100644
--- a/funasr/bin/sa_asr_inference.py
+++ b/funasr/bin/sa_asr_inference.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
import sys
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
index 07b9b19..67106cf 100755
--- a/funasr/bin/sa_asr_train.py
+++ b/funasr/bin/sa_asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 8a9c6e9..9761497 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 24b8638..8e00730 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-
import argparse
import logging
@@ -174,6 +174,15 @@
return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sv":
+ return inference_sv(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
@@ -287,14 +296,6 @@
)
return parser
-
-
-def inference_launch(mode, **kwargs):
- if mode == "sv":
- return inference_sv(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index c83ceea..4ddcba4 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
from optparse import Option
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index 2b2b2ae..a8d67ef 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
@@ -179,6 +182,15 @@
return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "tp_norm":
+ return inference_tp(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Timestamp Prediction Inference",
@@ -264,13 +276,6 @@
)
return parser
-
-def inference_launch(mode, **kwargs):
- if mode == "tp_norm":
- return inference_tp(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
index 5835e77..245757c 100644
--- a/funasr/bin/vad_infer.py
+++ b/funasr/bin/vad_infer.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
import os
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 2ccc716..1f17c5b 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,6 +1,8 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
torch.set_num_threads(1)
@@ -267,6 +269,17 @@
return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "offline":
+ return inference_vad(**kwargs)
+ elif mode == "online":
+ return inference_vad_online(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="VAD Decoding",
@@ -357,15 +370,6 @@
)
return parser
-
-def inference_launch(mode, **kwargs):
- if mode == "offline":
- return inference_vad(**kwargs)
- elif mode == "online":
- return inference_vad_online(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
--
Gitblit v1.9.1