游雁
2023-05-16 d105ce0d6b63bcd14edeb426fbc0acf593296be3
inference
21个文件已修改
2个文件已删除
2050 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 345 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 604 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_train.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_train_paraformer.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/build_trainer.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_infer.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_inference_launch.py 48 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_train.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_calc_perplexity.py 211 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_inference.py 406 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_inference_launch.py 297 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_train.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_infer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 24 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_train.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sa_asr_inference.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sa_asr_train.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_infer.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_inference_launch.py 19 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_infer.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference_launch.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_infer.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference_launch.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
@@ -19,13 +23,15 @@
import numpy as np
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
@@ -47,11 +53,10 @@
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
@@ -264,7 +269,6 @@
        
        assert check_return_type(results)
        return results
class Speech2TextParaformer:
    """Speech2Text class
@@ -839,7 +843,6 @@
        # assert check_return_type(results)
        return results
class Speech2TextUniASR:
    """Speech2Text class
@@ -1072,9 +1075,7 @@
        assert check_return_type(results)
        return results
class Speech2TextMFCCA:
    """Speech2Text class
@@ -1114,6 +1115,7 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
@@ -1270,3 +1272,330 @@
        return results
class Speech2TextTransducer:
    """Speech2Text class for Transducer models.

    Args:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        cmvn_file: CMVN file path; forwarded to the model builder and, when the
            training config declares a frontend, to ``WavFrontend``.
        beam_search_config: Extra keyword arguments forwarded to
            ``BeamSearchTransducer`` (an empty dict is used when ``None``).
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        token_type: Type of token units.
        bpemodel: BPE model path.
        device: Device to use for inference.
        beam_size: Size of beam during search.
        dtype: Data type.
        lm_weight: Language model weight.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        nbest: Number of final hypothesis.
        streaming: Whether to perform chunk-by-chunk inference.
        simu_streaming: Whether to enable the simulated-streaming path
            (see ``simu_streaming_decode``).
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
            Accepted for interface compatibility; not read by this class.
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2TextTransducer object.

        Raises:
            ValueError: If ``quantize_modules`` names anything other than
                "LSTM" / "Linear", or if float16 dynamic quantization is
                requested with torch < 1.5.0.
        """
        super().__init__()
        assert check_argument_types()

        # 1. Build ASR model (imported lazily, as in the sibling classes).
        from funasr.tasks.asr import ASRTransducerTask

        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

        # A frontend is only created when the training config describes one.
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)

        # 2. Optional dynamic quantization; otherwise cast to the requested dtype.
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all(q in ("LSTM", "Linear") for q in quantize_modules):
                    raise ValueError(
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )
                q_config = {getattr(torch.nn, q) for q in quantize_modules}
            else:
                q_config = {torch.nn.Linear}

            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch"
                    " version < 1.5.0. Switching to qint8 dtype instead."
                )
            q_dtype = getattr(torch, quantize_dtype)

            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()

        # 3. Optional language model used as an extra scorer in beam search.
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            lm_scorer = lm.lm
        else:
            lm_scorer = None

        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}

        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
            beam_size,
            lm=lm_scorer,
            lm_weight=lm_weight,
            nbest=nbest,
            **beam_search_config,
        )

        # 5. Tokenizer / converter; unset options fall back to the training args.
        token_list = asr_model.token_list

        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

        self.converter = converter
        self.tokenizer = tokenizer

        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        # Negative chunk / right-context sizes are clamped to zero.
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)

        # Either streaming flavour is switched off when it was not requested or
        # when chunk_size is 0; in both cases dynamic chunk training is disabled
        # on the encoder.
        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False

        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False

        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context

        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )

            # Minimum length the final chunk is padded to in streaming_decode().
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()

    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters.

        Clears the frontend cache, the encoder streaming cache, the beam-search
        cache and the processed-frame counter.
        """
        self.frontend_cache = None

        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()

        self.num_processed_frames = torch.tensor([[0]], device=self.device)

    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.

        Args:
            speech: Chunk of speech data. The padding code indexes it as a
                2-D array (dim 0 = frames, dim 1 = feature size).
            is_final: Whether speech corresponds to the final chunk of data.

        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                # Zero-pad a too-short final chunk. The padding is created on
                # the chunk's own device so torch.cat never mixes devices.
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0),
                    speech.size(1),
                    dtype=speech.dtype,
                    device=speech.device,
                )
                speech = torch.cat([speech, pad], dim=0)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
            feats,
            feats_lengths,
            self.num_processed_frames,
            chunk_size=self.chunk_size,
            left_context=self.left_context,
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)

        self.num_processed_frames += self.chunk_size

        # The final chunk ends the utterance: clear all streaming state.
        if is_final:
            self.reset_inference_cache()

        return nbest_hyps

    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call using the encoder's simulated chunk forward.

        Args:
            speech: Speech data. (S)

        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.simu_chunk_forward(
            feats,
            feats_lengths,
            self.chunk_size,
            self.left_context,
            self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps

    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call (full-utterance, non-streaming).

        Args:
            speech: Speech data. (S)

        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

        # NOTE(review): unlike the two streaming paths, asr_model.normalize is
        # not applied here -- confirm this is intended.
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)

        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)

        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps

    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
        """Build partial or final results from the hypotheses.

        Args:
            nbest_hyps: N-best hypothesis.

        Returns:
            results: One ``(text, token, token_int, hyp)`` tuple per hypothesis;
                ``text`` is ``None`` when no tokenizer is available.
        """
        results = []

        for hyp in nbest_hyps:
            # Token id 0 is dropped from the sequence before conversion.
            token_int = [x for x in hyp.yseq if x != 0]

            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        # Validate once on the completed list instead of on every iteration.
        assert check_return_type(results)

        return results

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> "Speech2TextTransducer":
        """Build Speech2TextTransducer instance from the pretrained model.

        Args:
            model_tag: Model tag of the pretrained models. When given, the model
                is fetched with ``espnet_model_zoo`` and the unpacked entries
                override / extend ``kwargs``.
            **kwargs: Keyword arguments for the ``Speech2TextTransducer``
                constructor.

        Return:
            : Speech2TextTransducer instance.

        Raises:
            ImportError: If ``model_tag`` is given and ``espnet_model_zoo`` is
                not installed.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        # Must build this class: the kwargs (beam_search_config, quantize_*,
        # chunk_size, ...) belong to the transducer constructor.
        return Speech2TextTransducer(**kwargs)
funasr/bin/asr_inference_launch.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -61,15 +64,180 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_inference import SpeechText2Timestamp
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.bin.punctuation_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.asr_infer import Speech2TextTransducer
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    """Build a decoding function around ``Speech2Text``.

    Configures threads, logging, device and random seed, constructs the
    ``Speech2Text`` object once, and returns a closure that decodes a data
    source with it.

    Returns:
        ``_forward``: callable that runs decoding and returns a list of
        ``{'key': ..., 'value': ...}`` dicts, one per hypothesis with text.

    Raises:
        NotImplementedError: If ``batch_size > 1``, ``word_lm_train_config`` is
            given, or ``ngpu > 1``.
    """
    assert check_argument_types()
    # Local import: keeps this function self-contained w.r.t. file-level imports.
    from contextlib import ExitStack

    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Drop existing root handlers so that basicConfig() below takes effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode one data source and return the recognition results.

        Args:
            data_path_and_name_and_type: Data description handed to the
                streaming iterator; when ``None``, ``raw_inputs`` is used.
            raw_inputs: In-memory input, used only when the argument above is
                ``None``.
            output_dir_v2: Output directory overriding the outer ``output_dir``.
            fs: Forwarded to ``ASRTask.build_streaming_iterator``.
            param_dict: Accepted for interface compatibility; not read here.

        Returns:
            List of ``{'key': utterance_id, 'value': postprocessed_text}``.
        """
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir

        # The writer is a context manager; ExitStack closes it (when one was
        # created) even if decoding raises, so the output files are released.
        with ExitStack() as stack:
            if output_path is not None:
                writer = stack.enter_context(DatadirWriter(output_path))
            else:
                writer = None

            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

                # N-best list of (text, token, token_int, hyp_object)
                try:
                    results = speech2text(**batch)
                except TooShortUttError as e:
                    # Too-short input: emit a placeholder result per n-best slot.
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                    results = [[" ", ["sil"], [2], hyp]] * nbest

                # Only supporting batch_size==1
                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]

                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)

                    if text is not None:
                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                        item = {'key': key, 'value': text_postprocessed}
                        asr_result_list.append(item)
                        finish_count += 1
                        asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = text

                    logging.info("uttid: {}".format(key))
                    logging.info("text predictions: {}\n".format(text))
        return asr_result_list

    return _forward
def inference_paraformer(
@@ -161,7 +329,7 @@
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    
    if timestamp_model_file is not None:
        speechtext2timestamp = SpeechText2Timestamp(
        speechtext2timestamp = Speech2Timestamp(
            timestamp_cmvn_file=cmvn_file,
            timestamp_model_file=timestamp_model_file,
            timestamp_infer_config=timestamp_infer_config,
@@ -931,12 +1099,382 @@
    return _forward
def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    """Build a decoding function around ``Speech2TextMFCCA``.

    Configures threads, logging, device and random seed, constructs the
    ``Speech2TextMFCCA`` object once, and returns a closure that decodes a
    data source with it (the data iterator is always built with ``mc=True``).

    Returns:
        ``_forward``: callable that runs decoding and returns a list of
        ``{'key': ..., 'value': ...}`` dicts, one per hypothesis with text.

    Raises:
        NotImplementedError: If ``batch_size > 1``, ``word_lm_train_config`` is
            given, or ``ngpu > 1``.
    """
    assert check_argument_types()
    # Local import: keeps this function self-contained w.r.t. file-level imports.
    from contextlib import ExitStack

    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode one data source and return the recognition results.

        Args:
            data_path_and_name_and_type: Data description handed to the
                streaming iterator; when ``None``, ``raw_inputs`` is used.
            raw_inputs: In-memory input, used only when the argument above is
                ``None``.
            output_dir_v2: Output directory overriding the outer ``output_dir``.
            fs: Forwarded to ``ASRTask.build_streaming_iterator``.
            param_dict: Accepted for interface compatibility; not read here.

        Returns:
            List of ``{'key': utterance_id, 'value': postprocessed_text}``.
        """
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir

        # The writer is a context manager; ExitStack closes it (when one was
        # created) even if decoding raises, so the output files are released.
        with ExitStack() as stack:
            if output_path is not None:
                writer = stack.enter_context(DatadirWriter(output_path))
            else:
                writer = None

            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

                # N-best list of (text, token, token_int, hyp_object)
                try:
                    results = speech2text(**batch)
                except TooShortUttError as e:
                    # Too-short input: emit a placeholder result per n-best slot.
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                    results = [[" ", ["<space>"], [2], hyp]] * nbest

                # Only supporting batch_size==1
                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]

                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)

                    if text is not None:
                        text_postprocessed = postprocess_utils.sentence_postprocess(token)
                        # inference_asr unpacks this helper's return as a
                        # (sentence, extra) pair; when a tuple comes back keep
                        # only the sentence so 'value' is the text itself.
                        if isinstance(text_postprocessed, tuple):
                            text_postprocessed = text_postprocessed[0]
                        item = {'key': key, 'value': text_postprocessed}
                        asr_result_list.append(item)
                        finish_count += 1
                        asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = text
        return asr_result_list

    return _forward
def inference_transducer(
    output_dir: str,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    lm_weight: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    key_file: Optional[str],
    allow_variable_data_keys: bool,
    quantize_asr_model: Optional[bool],
    quantize_modules: Optional[List[str]],
    quantize_dtype: Optional[str],
    streaming: Optional[bool],
    simu_streaming: Optional[bool],
    chunk_size: Optional[int],
    left_context: Optional[int],
    right_context: Optional[int],
    display_partial_hypotheses: bool,
    **kwargs,
) -> None:
    """Transducer model inference.
    Args:
        output_dir: Output directory path.
        batch_size: Batch decoding size.
        dtype: Data type.
        beam_size: Beam size.
        ngpu: Number of GPUs.
        seed: Random number generator seed.
        lm_weight: Weight of language model.
        nbest: Number of final hypothesis.
        num_workers: Number of workers.
        log_level: Level of verbose for logs.
        data_path_and_name_and_type:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        beam_search_config: Beam search config path.
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        model_tag: Model tag.
        token_type: Type of token units.
        bpemodel: BPE model path.
        key_file: File key.
        allow_variable_data_keys: Whether to allow variable data keys.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        streaming: Whether to perform chunk-by-chunk inference.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        dtype=dtype,
        beam_size=beam_size,
        lm_weight=lm_weight,
        nbest=nbest,
        quantize_asr_model=quantize_asr_model,
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        simu_streaming=simu_streaming,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(
                speech2text.asr_train_args, False
            ),
            collate_fn=ASRTask.build_collate_fn(
                speech2text.asr_train_args, False
            ),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 4. Start for-loop
        with DatadirWriter(output_dir) as writer:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
                assert len(batch.keys()) == 1
                try:
                    if speech2text.streaming:
                        speech = batch["speech"]
                        _steps = len(speech) // speech2text._ctx
                        _end = 0
                        for i in range(_steps):
                            _end = (i + 1) * speech2text._ctx
                            speech2text.streaming_decode(
                                speech[i * speech2text._ctx : _end], is_final=False
                            )
                        final_hyps = speech2text.streaming_decode(
                            speech[_end : len(speech)], is_final=True
                        )
                    elif speech2text.simu_streaming:
                        final_hyps = speech2text.simu_streaming_decode(**batch)
                    else:
                        final_hyps = speech2text(**batch)
                    results = speech2text.hypotheses_to_results(final_hyps)
                except TooShortUttError as e:
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                    results = [[" ", ["<space>"], [2], hyp]] * nbest
                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    ibest_writer = writer[f"{n}best_recog"]
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    if text is not None:
                        ibest_writer["text"][key] = text
    return _forward
def inference_launch(**kwargs):
    """Select and build the ASR inference callable for ``kwargs['mode']``.

    Every keyword argument (``mode`` included) is passed through unchanged to
    the builder that matches the requested mode.

    Returns:
        Whatever the selected builder returns, or ``None`` (after an
        info-level log message) when ``mode`` is missing or not recognised.
    """
    if 'mode' not in kwargs:
        logging.info("Unknown decoding mode.")
        return None
    mode = kwargs['mode']

    # Exact names are tested before the "paraformer_vad" prefix match,
    # in the same order as the modes are listed here.
    if mode == "asr":
        return inference_asr(**kwargs)
    if mode == "uniasr":
        return inference_uniasr(**kwargs)
    if mode == "paraformer":
        return inference_paraformer(**kwargs)
    if mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    if mode.startswith("paraformer_vad"):
        return inference_paraformer_vad_punc(**kwargs)
    if mode == "mfcca":
        return inference_mfcca(**kwargs)
    if mode == "rnnt":
        return inference_transducer(**kwargs)

    logging.info("Unknown decoding mode: {}".format(mode))
    return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
@@ -946,7 +1484,7 @@
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
@@ -979,7 +1517,7 @@
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
@@ -990,12 +1528,12 @@
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument(
            "--mc",
            type=bool,
            default=False,
            help="MultiChannel input",
        )
        "--mc",
        type=bool,
        default=False,
        help="MultiChannel input",
    )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
@@ -1058,7 +1596,7 @@
        default={},
        help="The keyword arguments for transducer beam search.",
    )
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
@@ -1104,8 +1642,8 @@
        type=bool,
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )
    )
    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
@@ -1129,8 +1667,8 @@
        default="qint8",
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )
    )
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
@@ -1157,36 +1695,6 @@
        help="CTC weight in joint decoding",
    )
    return parser
def inference_launch(**kwargs):
    """Select and build the ASR inference callable for ``kwargs['mode']``.

    Every keyword argument (``mode`` included) is passed through unchanged to
    the selected builder. The "asr", "mfcca" and "rnnt" builders are imported
    only when their mode is requested.

    Returns:
        Whatever the selected builder returns, or ``None`` (after an
        info-level log message) when ``mode`` is missing or not recognised.
    """
    if 'mode' not in kwargs:
        logging.info("Unknown decoding mode.")
        return None
    mode = kwargs['mode']

    if mode == "asr":
        from funasr.bin.asr_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    if mode == "uniasr":
        return inference_uniasr(**kwargs)
    if mode == "paraformer":
        return inference_paraformer(**kwargs)
    if mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    if mode.startswith("paraformer_vad"):
        return inference_paraformer_vad_punc(**kwargs)
    if mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference_modelscope
        return inference_modelscope(**kwargs)
    if mode == "rnnt":
        from funasr.bin.asr_inference_rnnt import inference_modelscope
        return inference_modelscope(**kwargs)

    logging.info("Unknown decoding mode: {}".format(mode))
    return None
def main(cmd=None):
funasr/bin/asr_train.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
funasr/bin/asr_train_paraformer.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
@@ -9,6 +12,12 @@
def parse_args():
    parser = ASRTask.get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="asr",
        help="mode",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
funasr/bin/build_trainer.py
@@ -1,3 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
import yaml
funasr/bin/diar_infer.py
@@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -362,6 +363,30 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Build the speaker-diarization inference callable for ``mode``.

    "sond" and "sond_demo" both use ``inference_sond``; "eend-ola" uses
    ``inference_eend``. For "sond_demo" the demo defaults below are filled
    into ``kwargs['param_dict']`` for any key the caller did not supply
    (an existing dict is updated in place).

    Returns:
        The builder's result, or ``None`` (after an info-level log message)
        for an unrecognised mode.
    """
    if mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)

    if mode != "sond" and mode != "sond_demo":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    if mode == "sond_demo":
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        if user_params is None:
            # Absent or explicitly None: use the demo defaults wholesale.
            kwargs["param_dict"] = demo_defaults
        else:
            # Caller-supplied values win; only missing keys are filled in.
            for name, value in demo_defaults.items():
                if name not in user_params:
                    user_params[name] = value

    return inference_sond(mode=mode, **kwargs)
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
@@ -469,29 +494,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Build the speaker-diarization inference callable for ``mode``.

    "sond" and "sond_demo" both use ``inference_sond``; "eend-ola" uses
    ``inference_eend``. For "sond_demo" the demo defaults below are filled
    into ``kwargs['param_dict']`` for any key the caller did not supply
    (an existing dict is updated in place).

    Returns:
        The builder's result, or ``None`` (after an info-level log message)
        for an unrecognised mode.
    """
    if mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)

    if mode != "sond" and mode != "sond_demo":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    if mode == "sond_demo":
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        if user_params is None:
            # Absent or explicitly None: use the demo defaults wholesale.
            kwargs["param_dict"] = demo_defaults
        else:
            # Caller-supplied values win; only missing keys are filled in.
            for name, value in demo_defaults.items():
                if name not in user_params:
                    user_params[name] = value

    return inference_sond(mode=mode, **kwargs)
def main(cmd=None):
funasr/bin/diar_train.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
funasr/bin/lm_calc_perplexity.py
File was deleted
funasr/bin/lm_inference.py
File was deleted
funasr/bin/lm_inference_launch.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -14,8 +17,294 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
from funasr.tasks.lm import LMTask
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
def _format_sentence_ppl(tokens, sent_nll, log_base):
    """Render the per-token probability report for one sentence.

    Args:
        tokens: Token strings of the sentence, already terminated with
            ``</s>``.
        sent_nll: Torch tensor of negative log-likelihoods for the sentence.
            It is zipped with ``tokens`` for the per-token lines, while the
            sentence-level mean and sum are taken over the whole tensor.
        log_base: Base used in the ``log_base ** (x / log(log_base))``
            conversion; ``None`` selects ``np.exp`` directly.

    Returns:
        Tuple ``(report, sent_nll_sum)``: the formatted text block and the
        summed NLL as a NumPy value.
    """
    report = [" ".join(tokens) + "\n"]
    pre_word = "<s>"
    for word, word_nll in zip(tokens, sent_nll):
        # The model output is a negative log-likelihood; flip the sign.
        word_nll = -word_nll.cpu()
        if log_base is None:
            word_prob = np.exp(word_nll)
        else:
            word_prob = log_base ** (word_nll / np.log(log_base))
        report.append(
            '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
                cur=word,
                pre=pre_word,
                prob=round(word_prob.item(), 8),
                word_nll=round(word_nll.item(), 8)
            )
        )
        pre_word = word

    sent_nll_mean = sent_nll.mean().cpu().numpy()
    sent_nll_sum = sent_nll.sum().cpu().numpy()
    if log_base is None:
        sent_ppl = np.exp(sent_nll_mean)
    else:
        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
    report.append(
        'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
            sent_nll=round(-sent_nll_sum.item(), 4),
            sent_ppl=round(sent_ppl.item(), 4)
        )
    )
    return "".join(report), sent_nll_sum


def inference_lm(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    log_base: Optional[float] = 10,
    allow_variable_data_keys: bool = False,
    split_with_space: Optional[bool] = False,
    seg_dict_file: Optional[str] = None,
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build a language-model perplexity scorer and return it as a callable.

    The model is loaded once here; the returned ``_forward`` function scores
    either a single raw text string or a whole dataset.

    Args:
        batch_size: Batch size passed to the streaming data iterator.
        dtype: Name of the torch dtype the model is cast to; also passed to
            the data iterator.
        ngpu: Number of GPUs. ``>= 1`` selects CUDA when available; ``> 1``
            runs the model through ``data_parallel``.
        seed: Random seed.
        num_workers: Worker count for the data iterator.
        log_level: Accepted for interface compatibility; not used here.
        key_file: Key file passed to the data iterator.
        train_config: LM training config path.
        model_file: LM checkpoint path.
        log_base: Base used when converting log-likelihoods to probabilities
            and perplexity; ``None`` uses ``np.exp``.
        allow_variable_data_keys: Passed to the data iterator.
        split_with_space: Passed to ``LMPreprocessor``.
        seg_dict_file: Passed to ``LMPreprocessor``.
        output_dir: Default directory for written results; may be overridden
            per call with ``output_dir_v2``.
        param_dict: Accepted for interface compatibility; not used here.
        **kwargs: Only ``ncpu`` (torch thread count, default 1) is read.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, param_dict=None)``, which returns a list of
        ``{'key': ..., 'value': ...}`` dicts.
    """
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Model
    model, train_args = LMTask.build_model_from_file(
        train_config, model_file, device)
    wrapped_model = ForwardAdaptor(model, "nll")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
    logging.info(f"Model:\n{model}")

    preprocessor = LMPreprocessor(
        train=False,
        token_type=train_args.token_type,
        token_list=train_args.token_list,
        bpemodel=train_args.bpemodel,
        text_cleaner=train_args.cleaner,
        g2p_type=train_args.g2p,
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
        split_with_space=split_with_space,
        seg_dict_file=seg_dict_file
    )

    def _run_model(batch):
        """Move ``batch`` to the device and return ``(batch, nll, lengths)``."""
        batch = to_device(batch, device)
        if ngpu <= 1:
            # NOTE(kamo): data_parallel also should work with ngpu=1,
            # but for debuggability it's better to keep this block.
            nll, lengths = wrapped_model(**batch)
        else:
            nll, lengths = data_parallel(
                wrapped_model, (), range(ngpu), module_kwargs=batch
            )
        return batch, nll, lengths

    def _score_raw_text(raw_inputs, writer):
        """Score one raw text string under the fixed key "lm demo"."""
        key = "lm demo"
        line = raw_inputs.strip()
        if line == "":
            return [{'key': key, 'value': ""}]

        batch = preprocessor(key, {'text': line})

        #  Force data-precision
        for name in batch:
            value = batch[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )
            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype("float32")
            elif value.dtype.kind == "i":
                value = value.astype("long")
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            batch[name] = value

        # Add the batch dimension by hand: there is no collate step here.
        batch["text_lengths"] = torch.from_numpy(
            np.array([len(batch["text"])], dtype='int32'))
        batch["text"] = np.expand_dims(batch["text"], axis=0)

        results = []
        with torch.no_grad():
            batch, nll, _ = _run_model(batch)
            ids2tokens = preprocessor.token_id_converter.ids2tokens
            for sent_ids, sent_nll in zip(batch['text'], nll):
                ppl_out, _ = _format_sentence_ppl(
                    ids2tokens(sent_ids) + ['</s>'], sent_nll, log_base
                )
                if writer is not None:
                    writer["ppl"][key + ":\n"] = ppl_out
                results.append({'key': key, 'value': ppl_out})
        return results

    def _score_dataset(data_path_and_name_and_type, writer):
        """Score every utterance from the data iterator plus a corpus average."""
        results = []

        # 3. Build data-iterator
        loader = LMTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=LMTask.build_collate_fn(train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 4. Start for-loop
        total_nll = 0.0
        total_ntokens = 0
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                batch, nll, lengths = _run_model(batch)
                ids2tokens = preprocessor.token_id_converter.ids2tokens
                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
                    ppl_out, sent_nll_sum = _format_sentence_ppl(
                        ids2tokens(sent_ids) + ['</s>'], sent_nll, log_base
                    )
                    utt2nll = round(-sent_nll_sum.item(), 5)
                    if writer is not None:
                        writer["ppl"][key + ":\n"] = ppl_out
                        writer["utt2nll"][key] = str(utt2nll)
                    results.append({'key': key, 'value': ppl_out})

            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
            # nll: (B, L) -> (B,)
            nll = nll.detach().cpu().numpy().sum(1)
            # lengths: (B,)
            lengths = lengths.detach().cpu().numpy()
            total_nll += nll.sum()
            total_ntokens += lengths.sum()

        if total_ntokens == 0:
            # Nothing was scored (e.g. the iterator was empty); the average
            # below would divide by zero, so report what we have instead.
            logging.warning("No tokens were scored; skipping average perplexity.")
            return results

        if log_base is None:
            ppl = np.exp(total_nll / total_ntokens)
        else:
            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))

        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
            total_nll=round(-total_nll.item(), 4),
            total_ppl=round(ppl.item(), 4)
        )
        if writer is not None:
            writer["ppl"]["AVG PPL : "] = avg_ppl
        results.append({'key': 'AVG PPL', 'value': avg_ppl})
        return results

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: dict = None,
    ):
        """Score ``raw_inputs`` if given, otherwise the dataset at the data path.

        Results are also written under ``output_dir_v2`` (falling back to the
        builder's ``output_dir``) when either is set.
        """
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = DatadirWriter(output_path) if output_path is not None else None
        try:
            if raw_inputs is not None:
                return _score_raw_text(raw_inputs, writer)
            return _score_dataset(data_path_and_name_and_type, writer)
        finally:
            # Close the writer on every exit path so its output files are
            # flushed and released even when scoring raises.
            if writer is not None:
                writer.close()

    return _forward
def inference_launch(mode, **kwargs):
    """Build the LM inference callable; only mode "transformer" is supported.

    Returns ``None`` (after an info-level log message) for any other mode.
    ``mode`` itself is not forwarded to the builder.
    """
    if mode != "transformer":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_lm(**kwargs)
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Calc perplexity",
@@ -89,14 +378,6 @@
    group.add_argument("--model_file", type=str)
    group.add_argument("--mode", type=str, default="lm")
    return parser
def inference_launch(mode, **kwargs):
    """Build the LM inference callable; only mode "transformer" is supported.

    The builder is imported only when that mode is requested. Returns
    ``None`` (after an info-level log message) for any other mode.
    """
    if mode != "transformer":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    from funasr.bin.lm_inference import inference_modelscope
    return inference_modelscope(**kwargs)
def main(cmd=None):
funasr/bin/lm_train.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
funasr/bin/punc_infer.py
@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
funasr/bin/punc_inference_launch.py
@@ -1,5 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -175,6 +177,16 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Build the punctuation inference callable for ``mode``.

    "punc" selects ``inference_punc`` and "punc_VadRealtime" selects
    ``inference_punc_vad_realtime``; anything else logs an info message and
    returns ``None``. ``mode`` itself is not forwarded.
    """
    if mode == "punc":
        return inference_punc(**kwargs)
    if mode != "punc_VadRealtime":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_punc_vad_realtime(**kwargs)
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Punctuation inference",
@@ -228,16 +240,6 @@
    group.add_argument("--model_file", type=str)
    group.add_argument("--mode", type=str, default="punc")
    return parser
def inference_launch(mode, **kwargs):
    """Build the punctuation inference callable for ``mode``.

    "punc" selects ``inference_punc`` and "punc_VadRealtime" selects
    ``inference_punc_vad_realtime``; anything else logs an info message and
    returns ``None``. ``mode`` itself is not forwarded.
    """
    if mode == "punc":
        return inference_punc(**kwargs)
    if mode != "punc_VadRealtime":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_punc_vad_realtime(**kwargs)
def main(cmd=None):
funasr/bin/punc_train.py
@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
from funasr.tasks.punctuation import PunctuationTask
funasr/bin/sa_asr_inference.py
@@ -1,3 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
funasr/bin/sa_asr_train.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
funasr/bin/sv_infer.py
@@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -174,6 +174,15 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Build the speaker-verification inference callable (mode "sv" only).

    Returns ``None`` (after an info-level log message) for any other mode.
    ``mode`` itself is not forwarded to the builder.
    """
    if mode != "sv":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_sv(**kwargs)
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
@@ -287,14 +296,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Build the speaker-verification inference callable (mode "sv" only).

    Returns ``None`` (after an info-level log message) for any other mode.
    ``mode`` itself is not forwarded to the builder.
    """
    if mode != "sv":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_sv(**kwargs)
def main(cmd=None):
funasr/bin/tp_infer.py
@@ -1,3 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from optparse import Option
funasr/bin/tp_inference_launch.py
@@ -1,4 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
@@ -179,6 +182,15 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Build the timestamp-prediction inference callable (mode "tp_norm" only).

    Returns ``None`` (after an info-level log message) for any other mode.
    ``mode`` itself is not forwarded to the builder.
    """
    if mode != "tp_norm":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_tp(**kwargs)
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Timestamp Prediction Inference",
@@ -264,13 +276,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Build the timestamp-prediction inference callable (mode "tp_norm" only).

    Returns ``None`` (after an info-level log message) for any other mode.
    ``mode`` itself is not forwarded to the builder.
    """
    if mode != "tp_norm":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    return inference_tp(**kwargs)
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
funasr/bin/vad_infer.py
@@ -1,3 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
funasr/bin/vad_inference_launch.py
@@ -1,6 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
@@ -267,6 +269,17 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Build the VAD inference callable for ``mode``.

    "offline" selects ``inference_vad`` and "online" selects
    ``inference_vad_online``; anything else logs an info message and returns
    ``None``. ``mode`` itself is not forwarded.
    """
    if mode == "offline":
        return inference_vad(**kwargs)
    if mode == "online":
        return inference_vad_online(**kwargs)
    logging.info("Unknown decoding mode: {}".format(mode))
    return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="VAD Decoding",
@@ -357,15 +370,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Build the VAD inference callable for ``mode``.

    "offline" selects ``inference_vad`` and "online" selects
    ``inference_vad_online``; anything else logs an info message and returns
    ``None``. ``mode`` itself is not forwarded.
    """
    if mode == "offline":
        return inference_vad(**kwargs)
    if mode == "online":
        return inference_vad_online(**kwargs)
    logging.info("Unknown decoding mode: {}".format(mode))
    return None
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)