| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import sys |
| | |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from packaging.version import parse as V |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.modules.beam_search.beam_search import BeamSearch |
| | | # from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch |
| | | |
| | | from funasr.modules.beam_search.beam_search import Hypothesis |
| | | from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer |
| | | from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer |
| | | from funasr.modules.scorers.ctc import CTCPrefixScorer |
| | | from funasr.modules.scorers.length_bonus import LengthBonus |
| | | from funasr.modules.subsampling import TooShortUttError |
| | |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.bin.tp_infer import Speech2Timestamp |
| | | from funasr.bin.vad_inference import Speech2VadSegment |
| | | from funasr.bin.vad_infer import Speech2VadSegment |
| | | from funasr.bin.punc_infer import Text2Punc |
| | | from funasr.utils.vad_utils import slice_padding_fbank |
| | | from funasr.tasks.vad import VADTask |
| | | |
| | | from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard |
| | | |
| | | |
| | |
| | | |
| | | assert check_return_type(results) |
| | | return results |
| | | |
| | | |
| | | class Speech2TextParaformer: |
| | | """Speech2Text class |
| | |
| | | # assert check_return_type(results) |
| | | return results |
| | | |
| | | |
| | | class Speech2TextUniASR: |
| | | """Speech2Text class |
| | | |
| | |
| | | |
| | | assert check_return_type(results) |
| | | return results |
| | | |
| | | |
| | | |
| | | |
| | | |
| | | class Speech2TextMFCCA: |
| | | """Speech2Text class |
| | |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build ASR model |
| | | from funasr.tasks.asr import ASRTaskMFCCA as ASRTask |
| | | scorers = {} |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | |
| | | return results |
| | | |
| | | |
class Speech2TextTransducer:
    """Speech2Text class for Transducer models.

    Args:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        cmvn_file: CMVN file path; forwarded to the model builder and, when the
            training config defines a frontend, to ``WavFrontend``.
        beam_search_config: Extra keyword arguments forwarded to
            ``BeamSearchTransducer`` (``None`` means no extra arguments).
        lm_train_config: Language Model training config path.
        lm_file: Language Model file path.
        token_type: Type of token units (defaults to the training config's value).
        bpemodel: BPE model path (defaults to the training config's value).
        device: Device to use for inference.
        beam_size: Size of beam during search.
        dtype: Data type.
        lm_weight: Language model weight.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on
            (only "Linear" and "LSTM" are accepted).
        quantize_dtype: Dynamic quantization data type.
        nbest: Number of final hypothesis.
        streaming: Whether to perform chunk-by-chunk inference.
        simu_streaming: Whether to perform simulated streaming, i.e. the whole
            utterance is fed to the encoder's ``simu_chunk_forward``.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
            NOTE(review): accepted for interface compatibility but not used by
            this class.
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2Text object."""
        super().__init__()

        assert check_argument_types()
        from funasr.tasks.asr import ASRTransducerTask

        # 1. Build the ASR model and its training args from config + checkpoint.
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

        # Only build a frontend when the training config defines one.
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

        # 2. Optional dynamic quantization; otherwise cast to the requested dtype.
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all(q in ("LSTM", "Linear") for q in quantize_modules):
                    raise ValueError(
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )

                q_config = {getattr(torch.nn, q) for q in quantize_modules}
            else:
                q_config = {torch.nn.Linear}

            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                # Hard error: no fallback dtype is applied, so the message must
                # not claim one.
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch"
                    " version < 1.5.0."
                )
            q_dtype = getattr(torch, quantize_dtype)

            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()

        # 3. Optional external language model used during beam search.
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            lm_scorer = lm.lm
        else:
            lm_scorer = None

        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}

        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
            beam_size,
            lm=lm_scorer,
            lm_weight=lm_weight,
            nbest=nbest,
            **beam_search_config,
        )

        # 5. Tokenizer / token-id converter (fall back to training config values).
        token_list = asr_model.token_list

        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

        self.converter = converter
        self.tokenizer = tokenizer

        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)

        # A zero chunk size disables (simulated) streaming.
        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False

        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False

        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context

        if self.streaming:
            # Number of input frames consumed per streaming step.
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )

            # The final chunk is zero-padded up to this length in streaming_decode.
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()

    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
        self.frontend_cache = None

        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()

        self.num_processed_frames = torch.tensor([[0]], device=self.device)

    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.

        Args:
            speech: Chunk of speech data. (S)
            is_final: Whether speech corresponds to the final chunk of data.
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if is_final:
            # Zero-pad a too-short final chunk along the time axis (dim 0).
            if self.streaming and speech.size(0) < self.last_chunk_length:
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
                )
                speech = torch.cat([speech, pad], dim=0)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
            feats,
            feats_lengths,
            self.num_processed_frames,
            chunk_size=self.chunk_size,
            left_context=self.left_context,
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)

        self.num_processed_frames += self.chunk_size

        if is_final:
            self.reset_inference_cache()

        return nbest_hyps

    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call using simulated (whole-utterance) chunk encoding.

        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.simu_chunk_forward(
            feats, feats_lengths, self.chunk_size, self.left_context, self.right_context
        )
        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps

    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call (full-utterance, non-streaming).

        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)

        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)

        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps

    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
        """Build partial or final results from the hypotheses.

        Args:
            nbest_hyps: N-best hypothesis.
        Returns:
            results: List of (text, token, token_int, hyp) tuples; ``text`` is
                None when no tokenizer is configured.
        """
        results = []

        for hyp in nbest_hyps:
            # Token id 0 is dropped before conversion to tokens.
            token_int = [t for t in hyp.yseq if t != 0]

            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)

        return results

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> "Speech2TextTransducer":
        """Build a Speech2TextTransducer instance, optionally from a pretrained model.

        Args:
            model_tag: Model tag of the pretrained models. When given, the model
                is downloaded via ``espnet_model_zoo`` and its files are merged
                into ``kwargs``.
            **kwargs: Constructor arguments for Speech2TextTransducer.
        Return:
            : Speech2TextTransducer instance.
        Raises:
            ImportError: if ``model_tag`` is given but ``espnet_model_zoo`` is missing.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        # Must build the transducer class itself: kwargs carry transducer-only
        # options (beam_search_config, quantize_*, simu_streaming, ...) that the
        # generic Speech2Text does not accept.
        return Speech2TextTransducer(**kwargs)
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.bin.tp_inference import SpeechText2Timestamp |
| | | from funasr.bin.vad_inference import Speech2VadSegment |
| | | from funasr.bin.punctuation_infer import Text2Punc |
| | | |
| | | |
| | | from funasr.utils.vad_utils import slice_padding_fbank |
| | | from funasr.tasks.vad import VADTask |
| | | from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard |
| | | from funasr.bin.asr_infer import Speech2Text |
| | | from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline |
| | | from funasr.bin.asr_infer import Speech2TextUniASR |
| | | from funasr.bin.asr_infer import Speech2TextMFCCA |
| | | from funasr.bin.vad_infer import Speech2VadSegment |
| | | from funasr.bin.punc_infer import Text2Punc |
| | | from funasr.bin.tp_infer import Speech2Timestamp |
| | | from funasr.bin.asr_infer import Speech2TextTransducer |
| | | |
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    """Load a ``Speech2Text`` model once and return a reusable decoding closure.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, fs=None, param_dict=None, **kwargs)``. It decodes
        every utterance yielded by the data iterator and returns a list of
        ``{'key': utt_id, 'value': postprocessed_text}`` dicts for hypotheses
        that produced text.

    Raises:
        NotImplementedError: batch_size > 1, a word LM config, or ngpu > 1.
    """
    assert check_argument_types()
    torch.set_num_threads(kwargs.get("ncpu", 1))

    # Reject unsupported configurations before doing any work.
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # basicConfig is a no-op while the root logger has handlers, so clear them
    # first to make log_level/format take effect.
    for stale_handler in list(logging.root.handlers):
        logging.root.removeHandler(stale_handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = {
        "asr_train_config": asr_train_config,
        "asr_model_file": asr_model_file,
        "cmvn_file": cmvn_file,
        "lm_train_config": lm_train_config,
        "lm_file": lm_file,
        "token_type": token_type,
        "bpemodel": bpemodel,
        "device": device,
        "maxlenratio": maxlenratio,
        "minlenratio": minlenratio,
        "dtype": dtype,
        "beam_size": beam_size,
        "ctc_weight": ctc_weight,
        "lm_weight": lm_weight,
        "ngram_weight": ngram_weight,
        "penalty": penalty,
        "nbest": nbest,
        "streaming": streaming,
    }
    logging.info(f"speech2text_kwargs: {speech2text_kwargs}")
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator; an in-memory waveform becomes a single
        # (data, name, type) triple.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            waveform = raw_inputs.numpy() if isinstance(raw_inputs, torch.Tensor) else raw_inputs
            data_path_and_name_and_type = [waveform, "speech", "waveform"]

        train_args = speech2text.asr_train_args
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(train_args, False),
            collate_fn=ASRTask.build_collate_fn(train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        asr_result_list = []
        # A per-call output dir overrides the one given at build time.
        target_dir = output_dir if output_dir_v2 is None else output_dir_v2
        writer = None if target_dir is None else DatadirWriter(target_dir)

        # 4. Decode utterance by utterance.
        for utt_keys, utt_batch in loader:
            assert isinstance(utt_batch, dict), type(utt_batch)
            assert all(isinstance(s, str) for s in utt_keys), utt_keys
            batch_len = len(next(iter(utt_batch.values())))
            assert len(utt_keys) == batch_len, f"{len(utt_keys)} != {batch_len}"

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**utt_batch)
            except TooShortUttError as err:
                logging.warning(f"Utterance {utt_keys} {err}")
                placeholder = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], placeholder]] * nbest

            # Only batch_size == 1 is supported, so the first key is the utterance.
            utt_id = utt_keys[0]
            for rank, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                if writer is not None:
                    # One sub-directory per rank: outdir/{n}best_recog
                    rank_writer = writer[f"{rank}best_recog"]
                    rank_writer["token"][utt_id] = " ".join(token)
                    rank_writer["token_int"][utt_id] = " ".join(map(str, token_int))
                    rank_writer["score"][utt_id] = str(hyp.score)

                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    asr_result_list.append({'key': utt_id, 'value': text_postprocessed})
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        rank_writer["text"][utt_id] = text

                logging.info(f"uttid: {utt_id}")
                logging.info(f"text predictions: {text}\n")
        return asr_result_list

    return _forward
| | | |
| | | |
| | | def inference_paraformer( |
| | |
| | | speech2text = Speech2TextParaformer(**speech2text_kwargs) |
| | | |
| | | if timestamp_model_file is not None: |
| | | speechtext2timestamp = SpeechText2Timestamp( |
| | | speechtext2timestamp = Speech2Timestamp( |
| | | timestamp_cmvn_file=cmvn_file, |
| | | timestamp_model_file=timestamp_model_file, |
| | | timestamp_infer_config=timestamp_infer_config, |
| | |
| | | return _forward |
| | | |
| | | |
def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    """Load a multi-channel ``Speech2TextMFCCA`` model and return a decoding closure.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, fs=None, param_dict=None, **kwargs)``. It decodes
        every utterance from a multi-channel (``mc=True``) data iterator and
        returns a list of ``{'key': utt_id, 'value': postprocessed_text}`` dicts.

    Raises:
        NotImplementedError: batch_size > 1, a word LM config, or ngpu > 1.
    """
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator (raw waveforms become a single input triple).
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # A per-call output dir overrides the one given at build time.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                if writer is not None:
                    # Create a directory: outdir/{n}best_recog
                    ibest_writer = writer[f"{n}best_recog"]

                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)

                if text is not None:
                    # sentence_postprocess returns a 2-tuple with the sentence
                    # first (inference_asr unpacks it the same way); store only
                    # the sentence rather than the whole tuple.
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list

    return _forward
| | | |
| | | def inference_transducer( |
| | | output_dir: str, |
| | | batch_size: int, |
| | | dtype: str, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | seed: int, |
| | | lm_weight: float, |
| | | nbest: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | data_path_and_name_and_type: Sequence[Tuple[str, str, str]], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | cmvn_file: Optional[str], |
| | | beam_search_config: Optional[dict], |
| | | lm_train_config: Optional[str], |
| | | lm_file: Optional[str], |
| | | model_tag: Optional[str], |
| | | token_type: Optional[str], |
| | | bpemodel: Optional[str], |
| | | key_file: Optional[str], |
| | | allow_variable_data_keys: bool, |
| | | quantize_asr_model: Optional[bool], |
| | | quantize_modules: Optional[List[str]], |
| | | quantize_dtype: Optional[str], |
| | | streaming: Optional[bool], |
| | | simu_streaming: Optional[bool], |
| | | chunk_size: Optional[int], |
| | | left_context: Optional[int], |
| | | right_context: Optional[int], |
| | | display_partial_hypotheses: bool, |
| | | **kwargs, |
| | | ) -> None: |
| | | """Transducer model inference. |
| | | Args: |
| | | output_dir: Output directory path. |
| | | batch_size: Batch decoding size. |
| | | dtype: Data type. |
| | | beam_size: Beam size. |
| | | ngpu: Number of GPUs. |
| | | seed: Random number generator seed. |
| | | lm_weight: Weight of language model. |
| | | nbest: Number of final hypothesis. |
| | | num_workers: Number of workers. |
| | | log_level: Level of verbose for logs. |
| | | data_path_and_name_and_type: |
| | | asr_train_config: ASR model training config path. |
| | | asr_model_file: ASR model path. |
| | | beam_search_config: Beam search config path. |
| | | lm_train_config: Language Model training config path. |
| | | lm_file: Language Model path. |
| | | model_tag: Model tag. |
| | | token_type: Type of token units. |
| | | bpemodel: BPE model path. |
| | | key_file: File key. |
| | | allow_variable_data_keys: Whether to allow variable data keys. |
| | | quantize_asr_model: Whether to apply dynamic quantization to ASR model. |
| | | quantize_modules: List of module names to apply dynamic quantization on. |
| | | quantize_dtype: Dynamic quantization data type. |
| | | streaming: Whether to perform chunk-by-chunk inference. |
| | | chunk_size: Number of frames in chunk AFTER subsampling. |
| | | left_context: Number of frames in left context AFTER subsampling. |
| | | right_context: Number of frames in right context AFTER subsampling. |
| | | display_partial_hypotheses: Whether to display partial hypotheses. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if ngpu > 1: |
| | | raise NotImplementedError("only single GPU decoding is supported") |
| | | |
| | | logging.basicConfig( |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | if ngpu >= 1: |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2. Build speech2text |
| | | speech2text_kwargs = dict( |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | cmvn_file=cmvn_file, |
| | | beam_search_config=beam_search_config, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | device=device, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | lm_weight=lm_weight, |
| | | nbest=nbest, |
| | | quantize_asr_model=quantize_asr_model, |
| | | quantize_modules=quantize_modules, |
| | | quantize_dtype=quantize_dtype, |
| | | streaming=streaming, |
| | | simu_streaming=simu_streaming, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | | ) |
| | | speech2text = Speech2TextTransducer.from_pretrained( |
| | | model_tag=model_tag, |
| | | **speech2text_kwargs, |
| | | ) |
| | | |
| | | def _forward(data_path_and_name_and_type, |
| | | raw_inputs: Union[np.ndarray, torch.Tensor] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | fs: dict = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | # 3. Build data-iterator |
| | | loader = ASRTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=ASRTask.build_preprocess_fn( |
| | | speech2text.asr_train_args, False |
| | | ), |
| | | collate_fn=ASRTask.build_collate_fn( |
| | | speech2text.asr_train_args, False |
| | | ), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | # 4 .Start for-loop |
| | | with DatadirWriter(output_dir) as writer: |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} |
| | | assert len(batch.keys()) == 1 |
| | | |
| | | try: |
| | | if speech2text.streaming: |
| | | speech = batch["speech"] |
| | | |
| | | _steps = len(speech) // speech2text._ctx |
| | | _end = 0 |
| | | for i in range(_steps): |
| | | _end = (i + 1) * speech2text._ctx |
| | | |
| | | speech2text.streaming_decode( |
| | | speech[i * speech2text._ctx : _end], is_final=False |
| | | ) |
| | | |
| | | final_hyps = speech2text.streaming_decode( |
| | | speech[_end : len(speech)], is_final=True |
| | | ) |
| | | elif speech2text.simu_streaming: |
| | | final_hyps = speech2text.simu_streaming_decode(**batch) |
| | | else: |
| | | final_hyps = speech2text(**batch) |
| | | |
| | | results = speech2text.hypotheses_to_results(final_hyps) |
| | | except TooShortUttError as e: |
| | | logging.warning(f"Utterance {keys} {e}") |
| | | hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) |
| | | results = [[" ", ["<space>"], [2], hyp]] * nbest |
| | | |
| | | key = keys[0] |
| | | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): |
| | | ibest_writer = writer[f"{n}best_recog"] |
| | | |
| | | ibest_writer["token"][key] = " ".join(token) |
| | | ibest_writer["token_int"][key] = " ".join(map(str, token_int)) |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | ibest_writer["text"][key] = text |
| | | |
| | | |
| | | return _forward |
| | | |
| | | |
| | | def inference_launch(**kwargs): |
| | | if 'mode' in kwargs: |
| | | mode = kwargs['mode'] |
| | | else: |
| | | logging.info("Unknown decoding mode.") |
| | | return None |
| | | if mode == "asr": |
| | | return inference_asr(**kwargs) |
| | | elif mode == "uniasr": |
| | | return inference_uniasr(**kwargs) |
| | | elif mode == "paraformer": |
| | | return inference_paraformer(**kwargs) |
| | | elif mode == "paraformer_streaming": |
| | | return inference_paraformer_online(**kwargs) |
| | | elif mode.startswith("paraformer_vad"): |
| | | return inference_paraformer_vad_punc(**kwargs) |
| | | elif mode == "mfcca": |
| | | return inference_mfcca(**kwargs) |
| | | elif mode == "rnnt": |
| | | return inference_transducer(**kwargs) |
| | | else: |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="ASR Decoding", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | |
| | | # Note(kamo): Use '_' instead of '-' as separator. |
| | | # '-' is confusing if written in yaml. |
| | | parser.add_argument( |
| | |
| | | choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
| | | help="The verbose level of logging", |
| | | ) |
| | | |
| | | |
| | | parser.add_argument("--output_dir", type=str, required=True) |
| | | parser.add_argument( |
| | | "--ngpu", |
| | |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | group.add_argument( |
| | | "--mc", |
| | | type=bool, |
| | | default=False, |
| | | help="MultiChannel input", |
| | | ) |
| | | |
| | | "--mc", |
| | | type=bool, |
| | | default=False, |
| | | help="MultiChannel input", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument( |
| | | "--vad_infer_config", |
| | |
| | | default={}, |
| | | help="The keyword arguments for transducer beam search.", |
| | | ) |
| | | |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | | group.add_argument( |
| | | "--batch_size", |
| | |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to display partial hypotheses during chunk-by-chunk inference.", |
| | | ) |
| | | |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Dynamic quantization related") |
| | | group.add_argument( |
| | | "--quantize_asr_model", |
| | |
| | | default="qint8", |
| | | choices=["float16", "qint8"], |
| | | help="Dtype for dynamic quantization.", |
| | | ) |
| | | |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | | "--token_type", |
| | |
| | | help="CTC weight in joint decoding", |
| | | ) |
| | | return parser |
| | | |
| | | |
| | | |
def inference_launch(**kwargs):
    """Route an ASR inference request to the builder selected by ``kwargs["mode"]``.

    "asr", "mfcca" and "rnnt" are served by ``inference_modelscope`` from their
    own ``funasr.bin`` modules, which are imported only when that mode is
    selected. The remaining modes use builders referenced by name in this
    module. All of ``kwargs`` (``mode`` included) is forwarded to the builder.

    Returns:
        The builder's return value, or ``None`` (logged at info level) when
        ``mode`` is absent or not recognised.
    """
    if "mode" not in kwargs:
        logging.info("Unknown decoding mode.")
        return None
    mode = kwargs["mode"]

    if mode in ("asr", "mfcca", "rnnt"):
        # Each of these backends lives in a separate module; import on demand.
        if mode == "asr":
            from funasr.bin.asr_inference import inference_modelscope
        elif mode == "mfcca":
            from funasr.bin.asr_inference_mfcca import inference_modelscope
        else:
            from funasr.bin.asr_inference_rnnt import inference_modelscope
        return inference_modelscope(**kwargs)

    if mode == "uniasr":
        return inference_uniasr(**kwargs)
    if mode == "paraformer":
        return inference_paraformer(**kwargs)
    if mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    # Prefix match: every "paraformer_vad*" variant shares one builder.
    if mode.startswith("paraformer_vad"):
        return inference_paraformer_vad_punc(**kwargs)

    logging.info("Unknown decoding mode: {}".format(mode))
    return None
| | | |
| | | |
| | | |
| | | def main(cmd=None): |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | |
| | | def parse_args(): |
| | | parser = ASRTask.get_parser() |
| | | parser.add_argument( |
| | | "--mode", |
| | | type=str, |
| | | default="asr", |
| | | help="mode", |
| | | ) |
| | | parser.add_argument( |
| | | "--gpu_id", |
| | | type=int, |
| | | default=0, |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | | import yaml |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request according to ``mode``.

    Args:
        mode: "sond" and "sond_demo" go to ``inference_sond``; "eend-ola"
            goes to ``inference_eend``. ``mode`` is forwarded to the builder
            as a keyword argument.
        **kwargs: Forwarded to the builder. For "sond_demo", ``param_dict``
            is completed with demo defaults for any key the caller did not set.

    Returns:
        The builder's return value, or ``None`` (logged at info level) for an
        unknown ``mode``.
    """
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    elif mode == "sond_demo":
        # Demo defaults; any value supplied by the caller takes precedence.
        param_dict = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        if kwargs.get("param_dict") is not None:
            # Merge into a fresh dict rather than writing the defaults into
            # the caller's own param_dict object.
            param_dict.update(kwargs["param_dict"])
        kwargs["param_dict"] = param_dict
        return inference_sond(mode=mode, **kwargs)
    elif mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Speaker Verification", |
| | |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request according to ``mode``.

    Args:
        mode: "sond" and "sond_demo" go to ``inference_sond``; "eend-ola"
            goes to ``inference_eend``. ``mode`` is forwarded to the builder
            as a keyword argument.
        **kwargs: Forwarded to the builder. For "sond_demo", ``param_dict``
            is completed with demo defaults for any key the caller did not set.

    Returns:
        The builder's return value, or ``None`` (logged at info level) for an
        unknown ``mode``.
    """
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    elif mode == "sond_demo":
        # Demo defaults; any value supplied by the caller takes precedence.
        param_dict = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        if kwargs.get("param_dict") is not None:
            # Merge into a fresh dict rather than writing the defaults into
            # the caller's own param_dict object.
            param_dict.update(kwargs["param_dict"])
        kwargs["param_dict"] = param_dict
        return inference_sond(mode=mode, **kwargs)
    elif mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
| | | |
| | | |
| | | def main(cmd=None): |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | |
| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.types import float_or_none |
| | | import argparse |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | import os |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | from typing import Any |
| | | from typing import List |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from torch.nn.parallel import data_parallel |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.datasets.preprocessor import LMPreprocessor |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.forward_adaptor import ForwardAdaptor |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import float_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
def _format_sentence_ppl(tokens, token_nll, log_base):
    """Render the per-token probability report for one scored sentence.

    Args:
        tokens: Token strings of the sentence, already terminated with "</s>".
        token_nll: Per-position negative log-likelihoods, one per entry of
            ``tokens``. It must be tensor-like, because ``.cpu()``, ``.mean()``
            and ``.sum()`` are called on it; presumably a 1-D torch tensor
            (TODO confirm).
        log_base: ``None`` or the base given to ``inference_lm``. Note that
            ``log_base ** (x / ln(log_base))`` equals ``exp(x)``, so both
            branches below yield the same numbers.

    Returns:
        ``(report, nll_sum)``. ``report`` is the multi-line text: the token
        line, one " p( cur | pre ) = prob [ logprob ]" line per token, and the
        "logprob= ... ppl= ..." summary. ``nll_sum`` is the summed NLL as a
        Python float.
    """
    lines = [" ".join(tokens) + "\n"]
    prev_word = "<s>"
    for word, word_nll in zip(tokens, token_nll):
        word_logprob = -word_nll.cpu()
        if log_base is None:
            word_prob = np.exp(word_logprob)
        else:
            word_prob = log_base ** (word_logprob / np.log(log_base))
        lines.append(
            ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
                cur=word,
                pre=prev_word,
                prob=round(word_prob.item(), 8),
                word_nll=round(word_logprob.item(), 8),
            )
        )
        prev_word = word

    sent_nll_mean = token_nll.mean().cpu().numpy()
    sent_nll_sum = token_nll.sum().cpu().numpy()
    if log_base is None:
        sent_ppl = np.exp(sent_nll_mean)
    else:
        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
    lines.append(
        'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
            sent_nll=round(-sent_nll_sum.item(), 4),
            sent_ppl=round(sent_ppl.item(), 4),
        )
    )
    return "".join(lines), sent_nll_sum.item()


def inference_lm(
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        log_base: Optional[float] = 10,
        allow_variable_data_keys: bool = False,
        split_with_space: Optional[bool] = False,
        seg_dict_file: Optional[str] = None,
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    """Build a language-model perplexity scorer and return it as ``_forward``.

    The model is loaded with ``LMTask.build_model_from_file`` and wrapped with
    ``ForwardAdaptor(model, "nll")``. Text is tokenised by an
    ``LMPreprocessor`` that uses the training config's token settings.

    Args:
        batch_size, num_workers, key_file, allow_variable_data_keys: Passed to
            ``LMTask.build_streaming_iterator`` when scoring a dataset.
        dtype: Name of a torch dtype (resolved with ``getattr(torch, dtype)``).
        ngpu: Use CUDA when ``>= 1`` and CUDA is available. When ``> 1``,
            ``data_parallel`` runs over ``range(ngpu)``.
        seed: Passed to ``set_all_random_seed``.
        log_level: Accepted for interface compatibility; not used here.
        train_config, model_file: LM config and checkpoint.
        log_base: Used for probability/perplexity reporting. See
            ``_format_sentence_ppl`` for why it does not change the values.
        split_with_space, seg_dict_file: Passed to ``LMPreprocessor``.
        output_dir: Default output directory for "ppl"/"utt2nll" files.
        param_dict: Not used here.
        **kwargs: ``ncpu`` (default 1) sets ``torch.set_num_threads``.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, param_dict=None)``. It returns a list of
        ``{'key': ..., 'value': ...}`` dicts: one report per sentence, plus a
        final "AVG PPL" entry when a dataset is scored.
    """
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Model
    model, train_args = LMTask.build_model_from_file(
        train_config, model_file, device)
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    logging.info(f"Model:\n{model}")

    preprocessor = LMPreprocessor(
        train=False,
        token_type=train_args.token_type,
        token_list=train_args.token_list,
        bpemodel=train_args.bpemodel,
        text_cleaner=train_args.cleaner,
        g2p_type=train_args.g2p,
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
        split_with_space=split_with_space,
        seg_dict_file=seg_dict_file
    )

    def _run_model(batch):
        """Move ``batch`` to ``device`` and return ``(batch, nll, lengths)``."""
        with torch.no_grad():
            batch = to_device(batch, device)
            if ngpu <= 1:
                # NOTE(kamo): data_parallel also should work with ngpu=1,
                # but for debuggability it's better to keep this block.
                nll, lengths = wrapped_model(**batch)
            else:
                nll, lengths = data_parallel(
                    wrapped_model, (), range(ngpu), module_kwargs=batch
                )
        return batch, nll, lengths

    def _report_sentences(keys, batch, nll, writer, results, write_utt2nll):
        """Format one report per sentence, append it to ``results`` and write it out."""
        ids2tokens = preprocessor.token_id_converter.ids2tokens
        for key, sent_ids, text_len, sent_nll in zip(
                keys, batch["text"], batch["text_lengths"], nll):
            # Rows of a padded batch are longer than the sentence. Keep the
            # text_len real ids and the text_len + 1 NLL positions (tokens plus
            # "</s>") so padding is neither rendered nor averaged into the mean.
            # Assumes the loader's collate supplies "text_lengths" — TODO confirm.
            text_len = int(text_len)
            tokens = ids2tokens(sent_ids[:text_len]) + ['</s>']
            ppl_out, nll_sum = _format_sentence_ppl(
                tokens, sent_nll[:text_len + 1], log_base)
            if writer is not None:
                writer["ppl"][key + ":\n"] = ppl_out
                if write_utt2nll:
                    writer["utt2nll"][key] = str(round(-nll_sum, 5))
            results.append({'key': key, 'value': ppl_out})

    def _score_raw_text(raw_inputs, writer):
        """Score a single in-memory text string under the key "lm demo"."""
        results = []
        line = raw_inputs.strip()
        key = "lm demo"
        if line == "":
            results.append({'key': key, 'value': ""})
            return results
        batch = {'text': line}
        if preprocessor is not None:
            batch = preprocessor(key, batch)

        # Force data-precision
        for name in batch:
            value = batch[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )
            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype("float32")
            elif value.dtype.kind == "i":
                value = value.astype("long")
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            batch[name] = value

        batch["text_lengths"] = torch.from_numpy(
            np.array([len(batch["text"])], dtype='int32'))
        batch["text"] = np.expand_dims(batch["text"], axis=0)

        batch, nll, _ = _run_model(batch)
        _report_sentences([key], batch, nll, writer, results, write_utt2nll=False)
        return results

    def _score_dataset(data_path_and_name_and_type, writer):
        """Score every utterance from the data iterator and append the corpus PPL."""
        results = []
        # 3. Build data-iterator
        loader = LMTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=LMTask.build_collate_fn(train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 4. Start for-loop
        total_nll = 0.0
        total_ntokens = 0
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            batch, nll, lengths = _run_model(batch)
            _report_sentences(keys, batch, nll, writer, results, write_utt2nll=True)

            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
            # nll: (B, L) -> (B,)
            nll = nll.detach().cpu().numpy().sum(1)
            # lengths: (B,)
            lengths = lengths.detach().cpu().numpy()
            total_nll += nll.sum()
            total_ntokens += lengths.sum()

        if total_ntokens == 0:
            # Without this guard an empty loader crashes with a divide-by-zero.
            logging.warning("No tokens were scored; skipping the corpus-level perplexity.")
            return results

        if log_base is None:
            ppl = np.exp(total_nll / total_ntokens)
        else:
            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))

        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
            total_nll=round(-float(total_nll), 4),
            total_ppl=round(float(ppl), 4)
        )
        if writer is not None:
            writer["ppl"]["AVG PPL : "] = avg_ppl
        results.append({'key': 'AVG PPL', 'value': avg_ppl})
        return results

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            param_dict: dict = None,
    ):
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = DatadirWriter(output_path) if output_path is not None else None
        try:
            if raw_inputs is not None:
                return _score_raw_text(raw_inputs, writer)
            return _score_dataset(data_path_and_name_and_type, writer)
        finally:
            # DatadirWriter is used as a context manager elsewhere in this
            # module; close it here so its outputs are released even on error.
            if writer is not None:
                writer.close()

    return _forward
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the language-model scorer for ``mode``; only "transformer" is supported.

    Returns ``inference_lm(**kwargs)``, or ``None`` (logged at info level) for
    any other ``mode``.
    """
    if mode != "transformer":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_lm(**kwargs)
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Calc perplexity", |
| | |
| | | group.add_argument("--model_file", type=str) |
| | | group.add_argument("--mode", type=str, default="lm") |
| | | return parser |
| | | |
def inference_launch(mode, **kwargs):
    """Build the language-model scorer for ``mode``; only "transformer" is supported.

    For "transformer", ``inference_modelscope`` is imported from
    ``funasr.bin.lm_inference`` on demand and called with ``kwargs``.
    Any other ``mode`` logs an info message and yields ``None``.
    """
    if mode != "transformer":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    from funasr.bin.lm_inference import inference_modelscope
    return inference_modelscope(**kwargs)
| | | |
| | | |
| | | def main(cmd=None): |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | from pathlib import Path |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Run the punctuation builder that matches ``mode``.

    "punc" selects ``inference_punc`` and "punc_VadRealtime" selects
    ``inference_punc_vad_realtime``; either is called with ``kwargs``.
    Other modes log an info message and return ``None``.
    """
    if mode == "punc":
        builder = inference_punc
    elif mode == "punc_VadRealtime":
        builder = inference_punc_vad_realtime
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return builder(**kwargs)
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Punctuation inference", |
| | |
| | | group.add_argument("--model_file", type=str) |
| | | group.add_argument("--mode", type=str, default="punc") |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Run the punctuation builder that matches ``mode``.

    Supported modes: "punc" (``inference_punc``) and "punc_VadRealtime"
    (``inference_punc_vad_realtime``). Anything else is logged at info level
    and yields ``None``.
    """
    if mode not in ("punc", "punc_VadRealtime"):
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    if mode == "punc":
        return inference_punc(**kwargs)
    return inference_punc_vad_realtime(**kwargs)
| | | |
| | | |
| | | def main(cmd=None): |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | from funasr.tasks.punctuation import PunctuationTask |
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import sys |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import argparse |
| | | import logging |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the speaker-verification scorer; only ``mode == "sv"`` is supported.

    Returns ``inference_sv(**kwargs)``, or ``None`` (logged at info level)
    for any other ``mode``.
    """
    if mode != "sv":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_sv(**kwargs)
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Speaker Verification", |
| | |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the speaker-verification scorer; only ``mode == "sv"`` is supported.

    Any other ``mode`` logs an info message and yields ``None``.
    """
    if mode == "sv":
        result = inference_sv(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        result = None
    return result
| | | |
| | | |
| | | def main(cmd=None): |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | from optparse import Option |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import argparse |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the timestamp-prediction runner; only ``mode == "tp_norm"`` is supported.

    Returns ``inference_tp(**kwargs)``, or ``None`` (logged at info level)
    for any other ``mode``.
    """
    if mode != "tp_norm":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_tp(**kwargs)
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Timestamp Prediction Inference", |
| | |
| | | ) |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the timestamp-prediction runner; only ``mode == "tp_norm"`` is supported.

    Any other ``mode`` logs an info message and yields ``None``.
    """
    if mode == "tp_norm":
        result = inference_tp(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        result = None
    return result
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the VAD runner for ``mode``.

    "offline" selects ``inference_vad`` and "online" selects
    ``inference_vad_online``; either is called with ``kwargs``. Other modes
    log an info message and return ``None``.
    """
    if mode == "offline":
        builder = inference_vad
    elif mode == "online":
        builder = inference_vad_online
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return builder(**kwargs)
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="VAD Decoding", |
| | |
| | | ) |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Build the VAD runner for ``mode`` ("offline" or "online").

    Anything else is logged at info level and yields ``None``.
    """
    if mode not in ("offline", "online"):
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    if mode == "offline":
        return inference_vad(**kwargs)
    return inference_vad_online(**kwargs)
| | | |
| | | def main(cmd=None): |
| | | print(get_commandline_args(), file=sys.stderr) |