From d105ce0d6b63bcd14edeb426fbc0acf593296be3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 May 2023 13:58:11 +0800
Subject: [PATCH] inference

---
 funasr/bin/sa_asr_train.py          |    3 
 funasr/bin/punc_infer.py            |    4 
 funasr/bin/sv_inference_launch.py   |   19 
 funasr/bin/lm_train.py              |    3 
 funasr/bin/vad_inference_launch.py  |   26 
 funasr/bin/diar_train.py            |    3 
 funasr/bin/asr_train.py             |    3 
 funasr/bin/diar_inference_launch.py |   48 
 funasr/bin/lm_inference_launch.py   |  297 +++++++++
 funasr/bin/punc_inference_launch.py |   24 
 funasr/bin/sa_asr_inference.py      |    5 
 funasr/bin/vad_infer.py             |    5 
 funasr/bin/diar_infer.py            |    1 
 funasr/bin/asr_infer.py             |  345 ++++++++++
 funasr/bin/punc_train.py            |    4 
 /dev/null                           |  406 -------------
 funasr/bin/tp_infer.py              |    5 
 funasr/bin/build_trainer.py         |    5 
 funasr/bin/sv_infer.py              |    1 
 funasr/bin/tp_inference_launch.py   |   19 
 funasr/bin/asr_inference_launch.py  |  604 +++++++++++++++++-
 funasr/bin/asr_train_paraformer.py  |    9 
 22 files changed, 1,308 insertions(+), 531 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 488be16..bcf5718 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import sys
@@ -19,13 +23,15 @@
 
 import numpy as np
 import torch
+from packaging.version import parse as V
 from typeguard import check_argument_types
 from typeguard import check_return_type
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.modules.beam_search.beam_search import BeamSearch
 # from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
 from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
 from funasr.modules.subsampling import TooShortUttError
@@ -47,11 +53,10 @@
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
+from funasr.bin.vad_infer import Speech2VadSegment
 from funasr.bin.punc_infer import Text2Punc
 from funasr.utils.vad_utils import slice_padding_fbank
 from funasr.tasks.vad import VADTask
-
 from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
 
 
@@ -264,7 +269,6 @@
         
         assert check_return_type(results)
         return results
-
 
 class Speech2TextParaformer:
     """Speech2Text class
@@ -839,7 +843,6 @@
         # assert check_return_type(results)
         return results
 
-
 class Speech2TextUniASR:
     """Speech2Text class
 
@@ -1072,9 +1075,7 @@
 
         assert check_return_type(results)
         return results
-
-
-    
+   
 
 class Speech2TextMFCCA:
     """Speech2Text class
@@ -1114,6 +1115,7 @@
         assert check_argument_types()
         
         # 1. Build ASR model
+        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
         scorers = {}
         asr_model, asr_train_args = ASRTask.build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
@@ -1270,3 +1272,330 @@
         return results
 
 
+class Speech2TextTransducer:
+    """Speech2Text class for Transducer models.
+    Args:
+        asr_train_config: ASR model training config path.
+        asr_model_file: ASR model path.
+        beam_search_config: Beam search config path.
+        lm_train_config: Language Model training config path.
+        lm_file: Language Model config path.
+        token_type: Type of token units.
+        bpemodel: BPE model path.
+        device: Device to use for inference.
+        beam_size: Size of beam during search.
+        dtype: Data type.
+        lm_weight: Language model weight.
+        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+        quantize_modules: List of module names to apply dynamic quantization on.
+        quantize_dtype: Dynamic quantization data type.
+        nbest: Number of final hypothesis.
+        streaming: Whether to perform chunk-by-chunk inference.
+        chunk_size: Number of frames in chunk AFTER subsampling.
+        left_context: Number of frames in left context AFTER subsampling.
+        right_context: Number of frames in right context AFTER subsampling.
+        display_partial_hypotheses: Whether to display partial hypotheses.
+    """
+    
+    def __init__(
+        self,
+        asr_train_config: Union[Path, str] = None,
+        asr_model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        beam_search_config: Dict[str, Any] = None,
+        lm_train_config: Union[Path, str] = None,
+        lm_file: Union[Path, str] = None,
+        token_type: str = None,
+        bpemodel: str = None,
+        device: str = "cpu",
+        beam_size: int = 5,
+        dtype: str = "float32",
+        lm_weight: float = 1.0,
+        quantize_asr_model: bool = False,
+        quantize_modules: List[str] = None,
+        quantize_dtype: str = "qint8",
+        nbest: int = 1,
+        streaming: bool = False,
+        simu_streaming: bool = False,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+        display_partial_hypotheses: bool = False,  # accepted but not read anywhere in this constructor
+    ) -> None:
+        """Load the transducer model, optional LM scorer, beam search and tokenizer, then set up streaming state."""
+        super().__init__()  # the class declares no base class, so this resolves to object.__init__
+        
+        assert check_argument_types()
+        from funasr.tasks.asr import ASRTransducerTask  # imported locally, only when an instance is built
+        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        
+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:  # frontend only if configured
+            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+        
+        if quantize_asr_model:
+            if quantize_modules is not None:
+                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+                    raise ValueError(
+                        "Only 'Linear' and 'LSTM' modules are currently supported"
+                        " by PyTorch and in --quantize_modules"
+                    )
+                
+                q_config = set([getattr(torch.nn, q) for q in quantize_modules])  # module names -> torch.nn classes
+            else:
+                q_config = {torch.nn.Linear}  # default: quantize Linear layers only
+            
+            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+                raise ValueError(  # NOTE(review): the message announces a qint8 fallback, but the code raises instead
+                    "float16 dtype for dynamic quantization is not supported with torch"
+                    " version < 1.5.0. Switching to qint8 dtype instead."
+                )
+            q_dtype = getattr(torch, quantize_dtype)
+            
+            asr_model = torch.quantization.quantize_dynamic(
+                asr_model, q_config, dtype=q_dtype
+            ).eval()
+        else:
+            asr_model.to(dtype=getattr(torch, dtype)).eval()  # non-quantized path: cast to the requested dtype
+        
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, device
+            )
+            lm_scorer = lm.lm  # the beam search is given the inner .lm attribute, not the wrapper
+        else:
+            lm_scorer = None
+        
+        # 4. Build BeamSearch object
+        if beam_search_config is None:
+            beam_search_config = {}
+        
+        beam_search = BeamSearchTransducer(
+            asr_model.decoder,
+            asr_model.joint_network,
+            beam_size,
+            lm=lm_scorer,
+            lm_weight=lm_weight,
+            nbest=nbest,
+            **beam_search_config,
+        )
+        
+        token_list = asr_model.token_list
+        
+        if token_type is None:
+            token_type = asr_train_args.token_type  # fall back to the value stored in the training args
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel  # fall back to the value stored in the training args
+        
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None  # "bpe" without a model file: no tokenizer, results carry text=None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+        
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+        
+        self.converter = converter
+        self.tokenizer = tokenizer
+        
+        self.beam_search = beam_search
+        self.streaming = streaming
+        self.simu_streaming = simu_streaming
+        self.chunk_size = max(chunk_size, 0)  # negative values are clamped to 0
+        self.left_context = left_context
+        self.right_context = max(right_context, 0)  # negative values are clamped to 0
+        
+        if not streaming or chunk_size == 0:  # a chunk size of 0 disables streaming even when requested
+            self.streaming = False
+            self.asr_model.encoder.dynamic_chunk_training = False
+        
+        if not simu_streaming or chunk_size == 0:  # same rule for simulated streaming
+            self.simu_streaming = False
+            self.asr_model.encoder.dynamic_chunk_training = False
+        
+        self.frontend = frontend
+        self.window_size = self.chunk_size + self.right_context
+        
+        if self.streaming:  # the attributes below exist only on streaming instances
+            self._ctx = self.asr_model.encoder.get_encoder_input_size(
+                self.window_size
+            )
+            
+            self.last_chunk_length = (
+                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+            )
+            self.reset_inference_cache()
+    
+    def reset_inference_cache(self) -> None:
+        """Clear the frontend cache, the encoder and beam-search caches, and the processed-frame counter."""
+        self.frontend_cache = None
+        
+        self.asr_model.encoder.reset_streaming_cache(
+            self.left_context, device=self.device
+        )
+        self.beam_search.reset_inference_cache()
+        
+        self.num_processed_frames = torch.tensor([[0]], device=self.device)  # restart the frame count at zero
+    
+    @torch.no_grad()
+    def streaming_decode(
+        self,
+        speech: Union[torch.Tensor, np.ndarray],
+        is_final: bool = True,
+    ) -> List[HypothesisTransducer]:
+        """Decode one chunk of input and advance the streaming state.
+        Args:
+            speech: Chunk of input, 2-D (dim 0 is padded, size(1) is read); gets a batch axis via unsqueeze(0).
+            is_final: Whether speech corresponds to the final chunk of data.
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if is_final:
+            if self.streaming and speech.size(0) < self.last_chunk_length:
+                pad = torch.zeros(  # allocate on speech's own device so torch.cat below cannot mix devices
+                    self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype, device=speech.device
+                )
+                speech = torch.cat([speech, pad],
+                                   dim=0)  # zero-pad a short final chunk up to last_chunk_length rows
+        
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+        
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        enc_out = self.asr_model.encoder.chunk_forward(
+            feats,
+            feats_lengths,
+            self.num_processed_frames,
+            chunk_size=self.chunk_size,
+            left_context=self.left_context,
+            right_context=self.right_context,
+        )
+        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+        
+        self.num_processed_frames += self.chunk_size  # the counter advances by chunk_size per call
+        
+        if is_final:
+            self.reset_inference_cache()  # the last chunk ends the utterance: clear all streaming caches
+        
+        return nbest_hyps
+    
+    @torch.no_grad()
+    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+        """Decode a whole input through the encoder's simu_chunk_forward, then run beam search.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        assert check_argument_types()
+        
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))  # add a leading batch axis of size 1
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+        
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
+                                                            self.right_context)
+        nbest_hyps = self.beam_search(enc_out[0])  # only the first (and only) batch entry is searched
+        
+        return nbest_hyps
+    
+    @torch.no_grad()
+    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+        """Decode a whole input in one pass: full encoder forward followed by beam search.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        assert check_argument_types()
+        
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))  # add a leading batch axis of size 1
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+        # NOTE(review): unlike both streaming paths, asr_model.normalize is not applied here - confirm intended
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        
+        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+        
+        nbest_hyps = self.beam_search(enc_out[0])  # only the first (and only) batch entry is searched
+        
+        return nbest_hyps
+    
+    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
+        """Convert beam-search hypotheses into (text, token, token_int, hyp) tuples.
+        Args:
+            nbest_hyps: N-best hypothesis.
+        Returns:
+            results: One tuple per hypothesis; text is None when no tokenizer is configured.
+        """
+        results = []
+        
+        for hyp in nbest_hyps:
+            token_int = [x for x in hyp.yseq if x != 0]  # drop id 0 (presumably the blank symbol - confirm)
+            
+            token = self.converter.ids2tokens(token_int)
+            
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+            results.append((text, token, token_int, hyp))
+        
+        assert check_return_type(results)  # validate once on the finished list, not on every iteration
+        
+        return results
+    
+    @staticmethod
+    def from_pretrained(
+        model_tag: Optional[str] = None,
+        **kwargs: Optional[Any],
+    ) -> "Speech2TextTransducer":
+        """Build a Speech2TextTransducer instance from the pretrained model.
+        Args:
+            model_tag: Model tag of the pretrained models.
+        Returns: a Speech2TextTransducer built from kwargs (plus the downloaded files when model_tag is set).
+        Raises: ImportError when model_tag is given but espnet_model_zoo is not installed.
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+            
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))  # downloaded entries override caller kwargs
+        
+        return Speech2TextTransducer(**kwargs)  # was Speech2Text: a different class with another signature
+
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1870032..4a55caa 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import argparse
 import logging
@@ -61,15 +64,180 @@
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_inference import SpeechText2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
-from funasr.bin.punctuation_infer import Text2Punc
+
+
 from funasr.utils.vad_utils import slice_padding_fbank
 from funasr.tasks.vad import VADTask
 from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
 from funasr.bin.asr_infer import Speech2Text
 from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
 from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.asr_infer import Speech2TextTransducer
+
+def inference_asr(  # builds one Speech2Text decoder and returns a callable that runs inference with it
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # NOTE: data_path_and_name_and_type is not taken here; it is the first argument of the returned _forward
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    mc: bool = False,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)  # optional thread count, defaults to 1
+    torch.set_num_threads(ncpu)
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    
+    for handler in logging.root.handlers[:]:  # drop existing root handlers so basicConfig below takes effect
+        logging.root.removeHandler(handler)
+    
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    
+    if ngpu >= 1 and torch.cuda.is_available():  # use the GPU only when requested and actually available
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+    )
+    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+    speech2text = Speech2Text(**speech2text_kwargs)
+    
+    def _forward(data_path_and_name_and_type,  # closure over speech2text and the settings captured above
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]  # wrap in-memory input as a triple
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            mc=mc,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        finish_count = 0
+        file_count = 1  # fixed denominator of the progress ratio reported below
+        # 4. Decode every batch yielded by the loader
+        # FIXME(kamo): The output format should be discussed about
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir  # per-call dir wins over default
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["sil"], [2], hyp]] * nbest  # placeholder so the utterance still yields nbest rows
+            
+            # Only supporting batch_size==1
+            key = keys[0]
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+                    
+                    # Write the result to each file
+                    ibest_writer["token"][key] = " ".join(token)
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                
+                if text is not None:
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)  # second item is unused
+                    item = {'key': key, 'value': text_postprocessed}
+                    asr_result_list.append(item)
+                    finish_count += 1
+                    asr_utils.print_progress(finish_count / file_count)
+                    if writer is not None:
+                        ibest_writer["text"][key] = text  # the file gets the raw text, not the postprocessed one
+                
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}\n".format(text))
+        return asr_result_list
+    
+    return _forward
 
 
 def inference_paraformer(
@@ -161,7 +329,7 @@
     speech2text = Speech2TextParaformer(**speech2text_kwargs)
     
     if timestamp_model_file is not None:
-        speechtext2timestamp = SpeechText2Timestamp(
+        speechtext2timestamp = Speech2Timestamp(
             timestamp_cmvn_file=cmvn_file,
             timestamp_model_file=timestamp_model_file,
             timestamp_infer_config=timestamp_infer_config,
@@ -931,12 +1099,382 @@
     return _forward
 
 
+def inference_mfcca(  # builds one Speech2TextMFCCA decoder and returns a callable that runs inference with it
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # NOTE: data_path_and_name_and_type is not taken here; it is the first argument of the returned _forward
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)  # optional thread count, defaults to 1
+    torch.set_num_threads(ncpu)
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    
+    if ngpu >= 1 and torch.cuda.is_available():  # use the GPU only when requested and actually available
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+    )
+    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+    speech2text = Speech2TextMFCCA(**speech2text_kwargs)
+    
+    def _forward(data_path_and_name_and_type,  # closure over speech2text and the settings captured above
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]  # wrap in-memory input as a triple
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            fs=fs,
+            mc=True,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        finish_count = 0
+        file_count = 1  # fixed denominator of the progress ratio reported below
+        # 4. Decode every batch yielded by the loader
+        # FIXME(kamo): The output format should be discussed about
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir  # per-call dir wins over default
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["<space>"], [2], hyp]] * nbest  # placeholder so the utterance still yields rows
+            
+            # Only supporting batch_size==1
+            key = keys[0]
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+                    
+                    # Write the result to each file
+                    ibest_writer["token"][key] = " ".join(token)
+                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                
+                if text is not None:
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)  # keep only the sentence
+                    item = {'key': key, 'value': text_postprocessed}
+                    asr_result_list.append(item)
+                    finish_count += 1
+                    asr_utils.print_progress(finish_count / file_count)
+                    if writer is not None:
+                        ibest_writer["text"][key] = text  # the file gets the raw text, not the postprocessed one
+        return asr_result_list
+    
+    return _forward
+
+def inference_transducer(
+    output_dir: str,
+    batch_size: int,
+    dtype: str,
+    beam_size: int,
+    ngpu: int,
+    seed: int,
+    lm_weight: float,
+    nbest: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str],
+    beam_search_config: Optional[dict],
+    lm_train_config: Optional[str],
+    lm_file: Optional[str],
+    model_tag: Optional[str],
+    token_type: Optional[str],
+    bpemodel: Optional[str],
+    key_file: Optional[str],
+    allow_variable_data_keys: bool,
+    quantize_asr_model: Optional[bool],
+    quantize_modules: Optional[List[str]],
+    quantize_dtype: Optional[str],
+    streaming: Optional[bool],
+    simu_streaming: Optional[bool],
+    chunk_size: Optional[int],
+    left_context: Optional[int],
+    right_context: Optional[int],
+    display_partial_hypotheses: bool,
+    **kwargs,
+):
+    """Build the transducer (RNN-T) decoding function.
+
+    Args:
+        output_dir: Output directory path.
+        batch_size: Batch decoding size (values above 1 are rejected).
+        dtype: Data type.
+        beam_size: Beam size.
+        ngpu: Number of GPUs (values above 1 are rejected).
+        seed: Random number generator seed.
+        lm_weight: Weight of language model.
+        nbest: Number of final hypotheses.
+        num_workers: Number of workers.
+        log_level: Level of verbose for logs.
+        data_path_and_name_and_type: Not read here; ``_forward`` takes its own.
+        asr_train_config: ASR model training config path.
+        asr_model_file: ASR model path.
+        cmvn_file: CMVN file path, forwarded to the model loader.
+        beam_search_config: Keyword arguments for transducer beam search.
+        lm_train_config: Language Model training config path.
+        lm_file: Language Model path.
+        model_tag: Model tag.
+        token_type: Type of token units.
+        bpemodel: BPE model path.
+        key_file: File key.
+        allow_variable_data_keys: Whether to allow variable data keys.
+        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+        quantize_modules: List of module names to apply dynamic quantization on.
+        quantize_dtype: Dynamic quantization data type.
+        streaming: Whether to perform chunk-by-chunk inference.
+        simu_streaming: Whether to decode with ``simu_streaming_decode``.
+        chunk_size: Number of frames in chunk AFTER subsampling.
+        left_context: Number of frames in left context AFTER subsampling.
+        right_context: Number of frames in right context AFTER subsampling.
+        display_partial_hypotheses: Whether to display partial hypotheses.
+
+    Returns:
+        The ``_forward`` closure: it decodes a data source, writes the n-best
+        files and returns a list of ``{'key': ..., 'value': text}`` items.
+    """
+    assert check_argument_types()
+
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    # Fall back to CPU when a GPU is requested but CUDA is unavailable.
+    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2text
+    speech2text = Speech2TextTransducer.from_pretrained(
+        model_tag=model_tag,
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        beam_search_config=beam_search_config,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        dtype=dtype,
+        beam_size=beam_size,
+        lm_weight=lm_weight,
+        nbest=nbest,
+        quantize_asr_model=quantize_asr_model,
+        quantize_modules=quantize_modules,
+        quantize_dtype=quantize_dtype,
+        streaming=streaming,
+        simu_streaming=simu_streaming,
+        chunk_size=chunk_size,
+        left_context=left_context,
+        right_context=right_context,
+    )
+
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+        """Decode the given data and return the recognised texts."""
+        # A per-call output directory overrides the one given at build time.
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        asr_result_list = []
+
+        # 3. Build data-iterator
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(
+                speech2text.asr_train_args, False
+            ),
+            collate_fn=ASRTask.build_collate_fn(
+                speech2text.asr_train_args, False
+            ),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        # 4 .Start for-loop
+        with DatadirWriter(output_path) as writer:
+            for keys, batch in loader:
+                assert isinstance(batch, dict), type(batch)
+                assert all(isinstance(s, str) for s in keys), keys
+                _bs = len(next(iter(batch.values())))
+                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+                # Batch size is 1: unwrap it and drop the "*_lengths" entries.
+                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+                assert len(batch.keys()) == 1
+
+                try:
+                    if speech2text.streaming:
+                        speech = batch["speech"]
+                        _steps = len(speech) // speech2text._ctx
+                        _end = 0
+                        for i in range(_steps):
+                            _end = (i + 1) * speech2text._ctx
+                            speech2text.streaming_decode(
+                                speech[i * speech2text._ctx : _end], is_final=False
+                            )
+                        final_hyps = speech2text.streaming_decode(speech[_end:], is_final=True)
+                    elif speech2text.simu_streaming:
+                        final_hyps = speech2text.simu_streaming_decode(**batch)
+                    else:
+                        final_hyps = speech2text(**batch)
+                    results = speech2text.hypotheses_to_results(final_hyps)
+                except TooShortUttError as e:
+                    logging.warning(f"Utterance {keys} {e}")
+                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+                    results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+                key = keys[0]
+                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                    ibest_writer = writer[f"{n}best_recog"]
+                    ibest_writer["token"][key] = " ".join(token)
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                    if text is not None:
+                        ibest_writer["text"][key] = text
+                        asr_result_list.append({'key': key, 'value': text})
+        return asr_result_list
+
+    return _forward
+
+
+def inference_launch(**kwargs):
+    """Dispatch to the ``inference_*`` builder selected by ``kwargs['mode']``;
+    returns its result, or ``None`` when the mode is missing or unknown."""
+    if 'mode' not in kwargs:
+        logging.info("Unknown decoding mode.")
+        return None
+    mode = kwargs['mode']
+    if mode == "asr":
+        return inference_asr(**kwargs)
+    if mode == "uniasr":
+        return inference_uniasr(**kwargs)
+    if mode == "paraformer":
+        return inference_paraformer(**kwargs)
+    if mode == "paraformer_streaming":
+        return inference_paraformer_online(**kwargs)
+    if mode.startswith("paraformer_vad"):
+        return inference_paraformer_vad_punc(**kwargs)
+    if mode == "mfcca":
+        return inference_mfcca(**kwargs)
+    if mode == "rnnt":
+        return inference_transducer(**kwargs)
+    logging.info("Unknown decoding mode: {}".format(mode))
+    return None
+
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="ASR Decoding",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-
+    
     # Note(kamo): Use '_' instead of '-' as separator.
     # '-' is confusing if written in yaml.
     parser.add_argument(
@@ -946,7 +1484,7 @@
         choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
         help="The verbose level of logging",
     )
-
+    
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument(
         "--ngpu",
@@ -979,7 +1517,7 @@
         default=1,
         help="The number of workers used for DataLoader",
     )
-
+    
     group = parser.add_argument_group("Input data related")
     group.add_argument(
         "--data_path_and_name_and_type",
@@ -990,12 +1528,12 @@
     group.add_argument("--key_file", type=str_or_none)
     group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
     group.add_argument(
-            "--mc",
-            type=bool,
-            default=False,
-            help="MultiChannel input",
-        )
-        
+        "--mc",
+        type=str2bool,
+        default=False,
+        help="MultiChannel input",
+    )
+    
     group = parser.add_argument_group("The model configuration related")
     group.add_argument(
         "--vad_infer_config",
@@ -1058,7 +1596,7 @@
         default={},
         help="The keyword arguments for transducer beam search.",
     )
-
+    
     group = parser.add_argument_group("Beam-search related")
     group.add_argument(
         "--batch_size",
@@ -1104,8 +1642,8 @@
         type=bool,
         default=False,
         help="Whether to display partial hypotheses during chunk-by-chunk inference.",
-    )    
-   
+    )
+    
     group = parser.add_argument_group("Dynamic quantization related")
     group.add_argument(
         "--quantize_asr_model",
@@ -1129,8 +1667,8 @@
         default="qint8",
         choices=["float16", "qint8"],
         help="Dtype for dynamic quantization.",
-    )    
-
+    )
+    
     group = parser.add_argument_group("Text converter related")
     group.add_argument(
         "--token_type",
@@ -1157,36 +1695,6 @@
         help="CTC weight in joint decoding",
     )
     return parser
-
-
-
-def inference_launch(**kwargs):
-    if 'mode' in kwargs:
-        mode = kwargs['mode']
-    else:
-        logging.info("Unknown decoding mode.")
-        return None
-    if mode == "asr":
-        from funasr.bin.asr_inference import inference_modelscope
-        return inference_modelscope(**kwargs)
-    elif mode == "uniasr":
-        return inference_uniasr(**kwargs)
-    elif mode == "paraformer":
-        return inference_paraformer(**kwargs)
-    elif mode == "paraformer_streaming":
-        return inference_paraformer_online(**kwargs)
-    elif mode.startswith("paraformer_vad"):
-        return inference_paraformer_vad_punc(**kwargs)
-    elif mode == "mfcca":
-        from funasr.bin.asr_inference_mfcca import inference_modelscope
-        return inference_modelscope(**kwargs)
-    elif mode == "rnnt":
-        from funasr.bin.asr_inference_rnnt import inference_modelscope
-        return inference_modelscope(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
 
 
 def main(cmd=None):
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index a43472c..0dec107 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import os
 
diff --git a/funasr/bin/asr_train_paraformer.py b/funasr/bin/asr_train_paraformer.py
index 76943d5..223be14 100755
--- a/funasr/bin/asr_train_paraformer.py
+++ b/funasr/bin/asr_train_paraformer.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import os
 
@@ -9,6 +12,12 @@
 def parse_args():
     parser = ASRTask.get_parser()
     parser.add_argument(
+        "--mode",
+        type=str,
+        default="asr",
+        help="mode",
+    )
+    parser.add_argument(
         "--gpu_id",
         type=int,
         default=0,
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 5c30fdb..df3434f 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import os
 
 import yaml
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
index f698a66..f2dcb1e 100755
--- a/funasr/bin/diar_infer.py
+++ b/funasr/bin/diar_infer.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 08004e8..69d37d6 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
@@ -362,6 +363,30 @@
     return _forward
 
 
+
+
+def inference_launch(mode, **kwargs):
+    """Dispatch to the diarization inference builder selected by ``mode``.
+
+    ``sond_demo`` adds default settings to a copy of ``param_dict``.
+    """
+    if mode == "sond":
+        return inference_sond(mode=mode, **kwargs)
+    elif mode == "sond_demo":
+        defaults = {
+            "extract_profile": True,
+            "sv_train_config": "sv.yaml",
+            "sv_model_file": "sv.pb",
+        }
+        # Caller-supplied entries win; merge into a new dict (no mutation).
+        kwargs["param_dict"] = {**defaults, **(kwargs.get("param_dict") or {})}
+        return inference_sond(mode=mode, **kwargs)
+    elif mode == "eend-ola":
+        return inference_eend(mode=mode, **kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Speaker Verification",
@@ -469,29 +494,6 @@
     )
 
     return parser
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "sond":
-        return inference_sond(mode=mode, **kwargs)
-    elif mode == "sond_demo":
-        param_dict = {
-            "extract_profile": True,
-            "sv_train_config": "sv.yaml",
-            "sv_model_file": "sv.pb",
-        }
-        if "param_dict" in kwargs and kwargs["param_dict"] is not None:
-            for key in param_dict:
-                if key not in kwargs["param_dict"]:
-                    kwargs["param_dict"][key] = param_dict[key]
-        else:
-            kwargs["param_dict"] = param_dict
-        return inference_sond(mode=mode, **kwargs)
-    elif mode == "eend-ola":
-        return inference_eend(mode=mode, **kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 
 def main(cmd=None):
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
index f76d1b9..16a4bd0 100755
--- a/funasr/bin/diar_train.py
+++ b/funasr/bin/diar_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import os
 
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
deleted file mode 100755
index 198d578..0000000
--- a/funasr/bin/lm_calc_perplexity.py
+++ /dev/null
@@ -1,211 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.lm import LMTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def calc_perplexity(
-    output_dir: str,
-    batch_size: int,
-    dtype: str,
-    ngpu: int,
-    seed: int,
-    num_workers: int,
-    log_level: Union[int, str],
-    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
-    key_file: Optional[str],
-    train_config: Optional[str],
-    model_file: Optional[str],
-    log_base: Optional[float],
-    allow_variable_data_keys: bool,
-):
-    assert check_argument_types()
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1:
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build LM
-    model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
-    # Wrape model to make model.nll() data-parallel
-    wrapped_model = ForwardAdaptor(model, "nll")
-    wrapped_model.to(dtype=getattr(torch, dtype)).eval()
-    logging.info(f"Model:\n{model}")
-
-    # 3. Build data-iterator
-    loader = LMTask.build_streaming_iterator(
-        data_path_and_name_and_type,
-        dtype=dtype,
-        batch_size=batch_size,
-        key_file=key_file,
-        num_workers=num_workers,
-        preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
-        collate_fn=LMTask.build_collate_fn(train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
-    )
-
-    # 4. Start for-loop
-    with DatadirWriter(output_dir) as writer:
-        total_nll = 0.0
-        total_ntokens = 0
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            with torch.no_grad():
-                batch = to_device(batch, device)
-                if ngpu <= 1:
-                    # NOTE(kamo): data_parallel also should work with ngpu=1,
-                    # but for debuggability it's better to keep this block.
-                    nll, lengths = wrapped_model(**batch)
-                else:
-                    nll, lengths = data_parallel(
-                        wrapped_model, (), range(ngpu), module_kwargs=batch
-                    )
-
-            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
-            # nll: (B, L) -> (B,)
-            nll = nll.detach().cpu().numpy().sum(1)
-            # lengths: (B,)
-            lengths = lengths.detach().cpu().numpy()
-            total_nll += nll.sum()
-            total_ntokens += lengths.sum()
-
-            for key, _nll, ntoken in zip(keys, nll, lengths):
-                if log_base is None:
-                    utt_ppl = np.exp(_nll / ntoken)
-                else:
-                    utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
-
-                # Write PPL of each utts for debugging or analysis
-                writer["utt2nll"][key] = str(-_nll)
-                writer["utt2ppl"][key] = str(utt_ppl)
-                writer["utt2ntokens"][key] = str(ntoken)
-
-        if log_base is None:
-            ppl = np.exp(total_nll / total_ntokens)
-        else:
-            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
-        with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f:
-            f.write(f"{ppl}\n")
-        with (Path(output_dir) / "base").open("w", encoding="utf-8") as f:
-            if log_base is None:
-                _log_base = np.e
-            else:
-                _log_base = log_base
-            f.write(f"{_log_base}\n")
-        logging.info(f"PPL={ppl}")
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Calc perplexity",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    parser.add_argument(
-        "--log_base",
-        type=float_or_none,
-        default=None,
-        help="The base of logarithm for Perplexity. "
-        "If None, napier's constant is used.",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument("--train_config", type=str)
-    group.add_argument("--model_file", type=str)
-
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-    calc_perplexity(**kwargs)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
deleted file mode 100644
index 76de6df..0000000
--- a/funasr/bin/lm_inference.py
+++ /dev/null
@@ -1,406 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.tasks.lm import LMTask
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-def inference(
-    output_dir: str,
-    batch_size: int,
-    dtype: str,
-    ngpu: int,
-    seed: int,
-    num_workers: int,
-    log_level: Union[int, str],
-    train_config: Optional[str],
-    model_file: Optional[str],
-    log_base: Optional[float],
-    key_file: Optional[str] = None,
-    allow_variable_data_keys: bool = False,
-    split_with_space: Optional[bool] = False,
-    seg_dict_file: Optional[str] = None,
-    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
-    raw_inputs: Union[List[Any], bytes, str] = None,
-    **kwargs,
-):
-    inference_pipeline = inference_modelscope(
-        output_dir=output_dir,
-        raw_inputs=raw_inputs,
-        batch_size=batch_size,
-        dtype=dtype,
-        ngpu=ngpu,
-        seed=seed,
-        num_workers=num_workers,
-        log_level=log_level,
-        key_file=key_file,
-        train_config=train_config,
-        model_file=model_file,
-        log_base = log_base,
-        allow_variable_data_keys = allow_variable_data_keys,
-        split_with_space=split_with_space,
-        seg_dict_file=seg_dict_file,
-        **kwargs,
-    )
-    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
-    batch_size: int,
-    dtype: str,
-    ngpu: int,
-    seed: int,
-    num_workers: int,
-    log_level: Union[int, str],
-    key_file: Optional[str],
-    train_config: Optional[str],
-    model_file: Optional[str],
-    log_base: Optional[float] = 10,
-    allow_variable_data_keys: bool = False,
-    split_with_space: Optional[bool] = False,
-    seg_dict_file: Optional[str] = None,
-    output_dir: Optional[str] = None,
-    param_dict: dict = None,
-    **kwargs,
-):
-    assert check_argument_types()
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build Model
-    model, train_args = LMTask.build_model_from_file(
-        train_config, model_file, device)
-    wrapped_model = ForwardAdaptor(model, "nll")
-    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-    logging.info(f"Model:\n{model}")
-
-    preprocessor = LMPreprocessor(
-        train=False,
-        token_type=train_args.token_type,
-        token_list=train_args.token_list,
-        bpemodel=train_args.bpemodel,
-        text_cleaner=train_args.cleaner,
-        g2p_type=train_args.g2p,
-        text_name="text",
-        non_linguistic_symbols=train_args.non_linguistic_symbols,
-        split_with_space=split_with_space,
-        seg_dict_file=seg_dict_file
-    )
-
-    def _forward(
-        data_path_and_name_and_type,
-        raw_inputs: Union[List[Any], bytes, str] = None,
-        output_dir_v2: Optional[str] = None,
-        param_dict: dict = None,
-    ):
-        results = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        if raw_inputs != None:
-            line = raw_inputs.strip()
-            key = "lm demo"
-            if line=="":
-                item = {'key': key, 'value': ""}
-                results.append(item)
-                return results
-            batch = {}
-            batch['text'] = line
-            if preprocessor != None:
-                batch = preprocessor(key, batch)
-            
-            #  Force data-precision
-            for name in batch:
-                value = batch[name]
-                if not isinstance(value, np.ndarray):
-                    raise RuntimeError(
-                        f"All values must be converted to np.ndarray object "
-                        f'by preprocessing, but "{name}" is still {type(value)}.'
-                    )
-                # Cast to desired type
-                if value.dtype.kind == "f":
-                    value = value.astype("float32")
-                elif value.dtype.kind == "i":
-                    value = value.astype("long")
-                else:
-                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
-                batch[name] = value
-            
-            batch["text_lengths"] = torch.from_numpy(
-                np.array([len(batch["text"])], dtype='int32'))
-            batch["text"] = np.expand_dims(batch["text"], axis=0)
-
-            with torch.no_grad():
-                batch = to_device(batch, device)
-                if ngpu <= 1:
-                    nll, lengths = wrapped_model(**batch)
-                else:
-                    nll, lengths = data_parallel(
-                        wrapped_model, (), range(ngpu), module_kwargs=batch
-                    )
-                ## compute ppl
-                ppl_out_batch = ""
-                ids2tokens = preprocessor.token_id_converter.ids2tokens
-                for sent_ids, sent_nll in zip(batch['text'], nll):
-                    pre_word = "<s>"
-                    cur_word = None
-                    sent_lst = ids2tokens(sent_ids) + ['</s>']
-                    ppl_out = " ".join(sent_lst) + "\n"
-                    for word, word_nll in zip(sent_lst, sent_nll):
-                        cur_word = word
-                        word_nll = -word_nll.cpu()
-                        if log_base is None:
-                            word_prob = np.exp(word_nll)
-                        else:
-                            word_prob = log_base ** (word_nll / np.log(log_base))
-                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
-                            cur=cur_word, 
-                            pre=pre_word, 
-                            prob=round(word_prob.item(), 8),
-                            word_nll=round(word_nll.item(), 8)
-                            )
-                        pre_word = cur_word
-                    
-                    sent_nll_mean = sent_nll.mean().cpu().numpy()
-                    sent_nll_sum = sent_nll.sum().cpu().numpy()
-                    if log_base is None:
-                        sent_ppl = np.exp(sent_nll_mean)
-                    else:
-                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
-                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
-                        sent_nll=round(-sent_nll_sum.item(), 4),
-                        sent_ppl=round(sent_ppl.item(), 4)
-                        )
-                    ppl_out_batch += ppl_out
-                    item = {'key': key, 'value': ppl_out}
-                    if writer is not None:
-                        writer["ppl"][key+":\n"] = ppl_out
-                    results.append(item)
-
-            return results
-                
-        # 3. Build data-iterator
-        loader = LMTask.build_streaming_iterator(
-            data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            preprocess_fn=preprocessor,
-            collate_fn=LMTask.build_collate_fn(train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
-        )
-
-        # 4. Start for-loop
-        total_nll = 0.0
-        total_ntokens = 0
-        ppl_out_all = ""
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            ppl_out_batch = ""
-            with torch.no_grad():
-                batch = to_device(batch, device)
-                if ngpu <= 1:
-                    # NOTE(kamo): data_parallel also should work with ngpu=1,
-                    # but for debuggability it's better to keep this block.
-                    nll, lengths = wrapped_model(**batch)
-                else:
-                    nll, lengths = data_parallel(
-                        wrapped_model, (), range(ngpu), module_kwargs=batch
-                    )
-                ## print ppl
-                ids2tokens = preprocessor.token_id_converter.ids2tokens
-                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
-                    pre_word = "<s>"
-                    cur_word = None
-                    sent_lst = ids2tokens(sent_ids) + ['</s>']
-                    ppl_out = " ".join(sent_lst) + "\n"
-                    for word, word_nll in zip(sent_lst, sent_nll):
-                        cur_word = word
-                        word_nll = -word_nll.cpu()
-                        if log_base is None:
-                            word_prob = np.exp(word_nll)
-                        else:
-                            word_prob = log_base ** (word_nll / np.log(log_base))
-                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
-                            cur=cur_word, 
-                            pre=pre_word, 
-                            prob=round(word_prob.item(), 8),
-                            word_nll=round(word_nll.item(), 8)
-                            )
-                        pre_word = cur_word
-                    
-                    sent_nll_mean = sent_nll.mean().cpu().numpy()
-                    sent_nll_sum = sent_nll.sum().cpu().numpy()
-                    if log_base is None:
-                        sent_ppl = np.exp(sent_nll_mean)
-                    else:
-                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
-                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
-                        sent_nll=round(-sent_nll_sum.item(), 4),
-                        sent_ppl=round(sent_ppl.item(), 4)
-                        )
-                    ppl_out_batch += ppl_out
-                    utt2nll = round(-sent_nll_sum.item(), 5)
-                    item = {'key': key, 'value': ppl_out}
-                    if writer is not None:
-                        writer["ppl"][key+":\n"] = ppl_out
-                        writer["utt2nll"][key] = str(utt2nll)
-                    results.append(item)
-
-            ppl_out_all += ppl_out_batch
-            
-            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
-            # nll: (B, L) -> (B,)
-            nll = nll.detach().cpu().numpy().sum(1)
-            # lengths: (B,)
-            lengths = lengths.detach().cpu().numpy()
-            total_nll += nll.sum()
-            total_ntokens += lengths.sum()
-
-        if log_base is None:
-            ppl = np.exp(total_nll / total_ntokens)
-        else:
-            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
-        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
-            total_nll=round(-total_nll.item(), 4),
-            total_ppl=round(ppl.item(), 4)
-            )
-        item = {'key': 'AVG PPL', 'value': avg_ppl}
-        ppl_out_all += avg_ppl
-        if writer is not None:
-            writer["ppl"]["AVG PPL : "] = avg_ppl
-        results.append(item)
-
-        return results
-
-    return _forward
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Calc perplexity",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=False)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    parser.add_argument(
-        "--log_base",
-        type=float_or_none,
-        default=10,
-        help="The base of logarithm for Perplexity. "
-             "If None, napier's constant is used.",
-        required=False
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        required=False
-    )
-    group.add_argument(
-        "--raw_inputs",
-        type=str,
-        required=False
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
-    group.add_argument("--split_with_space", type=str2bool, default=False)
-    group.add_argument("--seg_dict_file", type=str_or_none)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument("--train_config", type=str)
-    group.add_argument("--model_file", type=str)
-
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    inference(**kwargs)
-
-if __name__ == "__main__":
-    main()
-
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index dc6414f..0840e6e 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 
 
@@ -14,8 +17,294 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.types import float_or_none
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
 
 
+def inference_lm(
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    key_file: Optional[str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    log_base: Optional[float] = 10,
+    allow_variable_data_keys: bool = False,
+    split_with_space: Optional[bool] = False,
+    seg_dict_file: Optional[str] = None,
+    output_dir: Optional[str] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    """Build a perplexity scorer around a pretrained language model.
+
+    Loads the model named by ``train_config``/``model_file`` once and returns
+    a closure ``_forward(data_path_and_name_and_type, raw_inputs=None,
+    output_dir_v2=None, param_dict=None)``. The closure scores either the
+    single string ``raw_inputs`` or every utterance yielded by the data
+    iterator, and returns a list of ``{'key', 'value'}`` dicts holding the
+    per-token report (plus an ``AVG PPL`` entry in iterator mode).
+    """
+    assert check_argument_types()
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+    
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build Model
+    model, train_args = LMTask.build_model_from_file(
+        train_config, model_file, device)
+    wrapped_model = ForwardAdaptor(model, "nll")
+    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+    logging.info(f"Model:\n{model}")
+    
+    preprocessor = LMPreprocessor(
+        train=False,
+        token_type=train_args.token_type,
+        token_list=train_args.token_list,
+        bpemodel=train_args.bpemodel,
+        text_cleaner=train_args.cleaner,
+        g2p_type=train_args.g2p,
+        text_name="text",
+        non_linguistic_symbols=train_args.non_linguistic_symbols,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file
+    )
+    
+    def _forward(
+        data_path_and_name_and_type,
+        raw_inputs: Union[List[Any], bytes, str] = None,
+        output_dir_v2: Optional[str] = None,
+        param_dict: dict = None,
+    ):
+        results = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        if raw_inputs is not None:
+            line = raw_inputs.strip()
+            key = "lm demo"
+            if line == "":
+                item = {'key': key, 'value': ""}
+                results.append(item)
+                return results
+            batch = {}
+            batch['text'] = line
+            if preprocessor is not None:
+                batch = preprocessor(key, batch)
+            
+            #  Force data-precision
+            for name in batch:
+                value = batch[name]
+                if not isinstance(value, np.ndarray):
+                    raise RuntimeError(
+                        f"All values must be converted to np.ndarray object "
+                        f'by preprocessing, but "{name}" is still {type(value)}.'
+                    )
+                # Cast to desired type
+                if value.dtype.kind == "f":
+                    value = value.astype("float32")
+                elif value.dtype.kind == "i":
+                    value = value.astype("long")
+                else:
+                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+                batch[name] = value
+            
+            batch["text_lengths"] = torch.from_numpy(
+                np.array([len(batch["text"])], dtype='int32'))
+            batch["text"] = np.expand_dims(batch["text"], axis=0)
+            
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                # Per-token probability report for the single input sentence.
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for sent_ids, sent_nll in zip(batch['text'], nll):
+                    pre_word = "<s>"
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=word,
+                            pre=pre_word,
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                        )
+                        pre_word = word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                    )
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key + ":\n"] = ppl_out
+                    results.append(item)
+            
+            return results
+        
+        # 3. Build data-iterator
+        loader = LMTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=preprocessor,
+            collate_fn=LMTask.build_collate_fn(train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        # 4. Start for-loop
+        total_nll = 0.0
+        total_ntokens = 0
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    # NOTE(kamo): data_parallel also should work with ngpu=1,
+                    # but for debuggability it's better to keep this block.
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                # Per-token probability report for every utterance in the batch.
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+                    pre_word = "<s>"
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=word,
+                            pre=pre_word,
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                        )
+                        pre_word = word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                    )
+                    utt2nll = round(-sent_nll_sum.item(), 5)
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key + ":\n"] = ppl_out
+                        writer["utt2nll"][key] = str(utt2nll)
+                    results.append(item)
+            
+            # Accumulate corpus-level totals for the average perplexity.
+            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+            # nll: (B, L) -> (B,)
+            nll = nll.detach().cpu().numpy().sum(1)
+            # lengths: (B,)
+            lengths = lengths.detach().cpu().numpy()
+            total_nll += nll.sum()
+            total_ntokens += lengths.sum()
+        
+        # Nothing was scored: skip the average rather than divide by zero.
+        if not results:
+            return results
+        if log_base is None:
+            ppl = np.exp(total_nll / total_ntokens)
+        else:
+            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+        
+        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+            total_nll=round(-total_nll.item(), 4),
+            total_ppl=round(ppl.item(), 4)
+        )
+        item = {'key': 'AVG PPL', 'value': avg_ppl}
+        if writer is not None:
+            writer["ppl"]["AVG PPL : "] = avg_ppl
+        results.append(item)
+        
+        return results
+    
+    return _forward
+
+
+def inference_launch(mode, **kwargs):
+    # Only the "transformer" LM is supported; other modes log and yield None.
+    if mode != "transformer":
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+    return inference_lm(**kwargs)
+    
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Calc perplexity",
@@ -89,14 +378,6 @@
     group.add_argument("--model_file", type=str)
     group.add_argument("--mode", type=str, default="lm")
     return parser
-
-def inference_launch(mode, **kwargs):
-    if mode == "transformer":
-        from funasr.bin.lm_inference import inference_modelscope
-        return inference_modelscope(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 
 def main(cmd=None):
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index 8641465..22b5f9c 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import os
 
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 41c4da3..4b6cd27 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 from pathlib import Path
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 594a7be..7f60f81 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,5 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import argparse
 import logging
@@ -175,6 +177,16 @@
     return _forward
 
 
+
+def inference_launch(mode, **kwargs):
+    # Dispatch on ``mode``; unrecognised modes are logged and yield None.
+    if mode == "punc":
+        return inference_punc(**kwargs)
+    elif mode == "punc_VadRealtime":
+        return inference_punc_vad_realtime(**kwargs)
+    logging.info("Unknown decoding mode: {}".format(mode))
+    return None
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Punctuation inference",
@@ -228,16 +240,6 @@
     group.add_argument("--model_file", type=str)
     group.add_argument("--mode", type=str, default="punc")
     return parser
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "punc":
-        return inference_punc(**kwargs)
-    if mode == "punc_VadRealtime":
-        return inference_punc_vad_realtime(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 
 def main(cmd=None):
diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
index 61b63ec..aeded7b 100644
--- a/funasr/bin/punc_train.py
+++ b/funasr/bin/punc_train.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import os
 from funasr.tasks.punctuation import PunctuationTask
 
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
index c894f54..7a5ba83 100644
--- a/funasr/bin/sa_asr_inference.py
+++ b/funasr/bin/sa_asr_inference.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import sys
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
index 07b9b19..67106cf 100755
--- a/funasr/bin/sa_asr_train.py
+++ b/funasr/bin/sa_asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import os
 
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 8a9c6e9..9761497 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 24b8638..8e00730 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
-
 
 import argparse
 import logging
@@ -174,6 +174,15 @@
     return _forward
 
 
+
+
+def inference_launch(mode, **kwargs):
+    # "sv" is the only supported mode; anything else is logged and yields None.
+    if mode != "sv":
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+    return inference_sv(**kwargs)
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Speaker Verification",
@@ -287,14 +296,6 @@
     )
 
     return parser
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "sv":
-        return inference_sv(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 
 def main(cmd=None):
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index c83ceea..4ddcba4 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 from optparse import Option
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index 2b2b2ae..a8d67ef 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 
 import argparse
@@ -179,6 +182,15 @@
     return _forward
 
 
+
+
+def inference_launch(mode, **kwargs):
+    # "tp_norm" is the only supported mode; others are logged and yield None.
+    if mode != "tp_norm":
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+    return inference_tp(**kwargs)
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="Timestamp Prediction Inference",
@@ -264,13 +276,6 @@
     )
     return parser
 
-
-def inference_launch(mode, **kwargs):
-    if mode == "tp_norm":
-        return inference_tp(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
index 5835e77..245757c 100644
--- a/funasr/bin/vad_infer.py
+++ b/funasr/bin/vad_infer.py
@@ -1,3 +1,8 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import os
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 2ccc716..1f17c5b 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,6 +1,8 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 
 import torch
 torch.set_num_threads(1)
@@ -267,6 +269,17 @@
     return _forward
 
 
+
+
+def inference_launch(mode, **kwargs):
+    # Map ``mode`` to its pipeline builder; unknown modes log and yield None.
+    if mode == "offline":
+        return inference_vad(**kwargs)
+    if mode == "online":
+        return inference_vad_online(**kwargs)
+    logging.info("Unknown decoding mode: {}".format(mode))
+    return None
+
 def get_parser():
     parser = config_argparse.ArgumentParser(
         description="VAD Decoding",
@@ -357,15 +370,6 @@
     )
     return parser
 
-
-def inference_launch(mode, **kwargs):
-    if mode == "offline":
-        return inference_vad(**kwargs)
-    elif mode == "online":
-        return inference_vad_online(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)

--
Gitblit v1.9.1