From 7012ca2efc130103c4acd24e3678c7ae280f8db4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Dec 2023 20:08:55 +0800
Subject: [PATCH] funasr2 paraformer bicifparaformer contextualparaformer

---
 examples/industrial_data_pretraining/paraformer-large/run.sh |    2 
 funasr/models/paraformer/model.py                            |    2 
 /dev/null                                                    |  655 ------------------------------------------------------
 funasr/bin/train.py                                          |    7 
 funasr/utils/trainer.py                                      |    0 
 funasr/bin/export_model.py                                   |    0 
 setup.py                                                     |    2 
 funasr/models/model_class_factory.py                         |   22 -
 8 files changed, 5 insertions(+), 685 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/run.sh
index 8571974..ce1953c 100644
--- a/examples/industrial_data_pretraining/paraformer-large/run.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/run.sh
@@ -1,5 +1,5 @@
 
-cmd="funasr/cli/train_cli.py"
+cmd="funasr/bin/train.py"
 
 python $cmd \
 +model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
diff --git a/funasr/bin/argument.py b/funasr/bin/argument.py
deleted file mode 100644
index 0ea4ac9..0000000
--- a/funasr/bin/argument.py
+++ /dev/null
@@ -1,262 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import sys
-
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import config_argparse
-import argparse
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="ASR Decoding",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, default=None)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=1,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=False,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    parser.add_argument(
-        "--hotword",
-        type=str_or_none,
-        default=None,
-        help="hotword file path or hotwords seperated by space"
-    )
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-    group.add_argument(
-        "--mc",
-        type=bool,
-        default=False,
-        help="MultiChannel input",
-    )
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--punc_infer_config",
-        type=str,
-        help="PUNC infer configuration",
-    )
-    group.add_argument(
-        "--punc_model_file",
-        type=str,
-        help="PUNC model parameter file",
-    )
-    group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--asr_train_config",
-        type=str,
-        help="ASR training configuration",
-    )
-    group.add_argument(
-        "--asr_model_file",
-        type=str,
-        help="ASR model parameter file",
-    )
-    group.add_argument(
-        "--sv_model_file",
-        type=str,
-        help="SV model parameter file",
-    )
-    group.add_argument(
-        "--lm_train_config",
-        type=str,
-        help="LM training configuration",
-    )
-    group.add_argument(
-        "--lm_file",
-        type=str,
-        help="LM parameter file",
-    )
-    group.add_argument(
-        "--word_lm_train_config",
-        type=str,
-        help="Word LM training configuration",
-    )
-    group.add_argument(
-        "--word_lm_file",
-        type=str,
-        help="Word LM parameter file",
-    )
-    group.add_argument(
-        "--ngram_file",
-        type=str,
-        help="N-gram parameter file",
-    )
-    group.add_argument(
-        "--model_tag",
-        type=str,
-        help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
-    )
-    group.add_argument(
-        "--beam_search_config",
-        default={},
-        help="The keyword arguments for transducer beam search.",
-    )
-
-    group = parser.add_argument_group("Beam-search related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
-    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
-    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
-    group.add_argument(
-        "--maxlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain max output length. "
-             "If maxlenratio=0.0 (default), it uses a end-detect "
-             "function "
-             "to automatically find maximum hypothesis lengths."
-             "If maxlenratio<0.0, its absolute value is interpreted"
-             "as a constant max output length",
-    )
-    group.add_argument(
-        "--minlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain min output length",
-    )
-    group.add_argument(
-        "--ctc_weight",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
-    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
-    group.add_argument("--streaming", type=str2bool, default=False)
-    group.add_argument("--fake_streaming", type=str2bool, default=False)
-    group.add_argument("--full_utt", type=str2bool, default=False)
-    group.add_argument("--chunk_size", type=int, default=16)
-    group.add_argument("--left_context", type=int, default=16)
-    group.add_argument("--right_context", type=int, default=0)
-    group.add_argument(
-        "--display_partial_hypotheses",
-        type=bool,
-        default=False,
-        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
-    )
-
-    group = parser.add_argument_group("Dynamic quantization related")
-    group.add_argument(
-        "--quantize_asr_model",
-        type=bool,
-        default=False,
-        help="Apply dynamic quantization to ASR model.",
-    )
-    group.add_argument(
-        "--quantize_modules",
-        nargs="*",
-        default=None,
-        help="""Module names to apply dynamic quantization on.
-        The module names are provided as a list, where each name is separated
-        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
-        Each specified name should be an attribute of 'torch.nn', e.g.:
-        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
-    )
-    group.add_argument(
-        "--quantize_dtype",
-        type=str,
-        default="qint8",
-        choices=["float16", "qint8"],
-        help="Dtype for dynamic quantization.",
-    )
-
-    group = parser.add_argument_group("Text converter related")
-    group.add_argument(
-        "--token_type",
-        type=str_or_none,
-        default=None,
-        choices=["char", "bpe", None],
-        help="The token type for ASR model. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument(
-        "--bpemodel",
-        type=str_or_none,
-        default=None,
-        help="The model path of sentencepiece. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument("--token_num_relax", type=int, default=1, help="")
-    group.add_argument("--decoding_ind", type=int, default=0, help="")
-    group.add_argument("--decoding_mode", type=str, default="model1", help="")
-    group.add_argument(
-        "--ctc_weight2",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    return parser
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
deleted file mode 100644
index a1cede1..0000000
--- a/funasr/bin/asr_infer.py
+++ /dev/null
@@ -1,2004 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import codecs
-import copy
-import logging
-import os
-import re
-import tempfile
-from pathlib import Path
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import requests
-import torch
-from packaging.version import parse as V
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.modules.beam_search.beam_search import BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
-from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.tokenizer.build_tokenizer import build_tokenizer
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
-class Speech2Text:
-    """Speech2Text class
-
-    Examples:
-        >>> import librosa
-        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2text(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            batch_size: int = 1,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            streaming: bool = False,
-            frontend_conf: dict = None,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            if asr_train_args.frontend == 'wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-            else:
-                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
-                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        decoder = asr_model.decoder
-
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-        token_list = asr_model.token_list
-        scorers.update(
-            decoder=decoder,
-            ctc=ctc,
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device
-            )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-        from funasr.modules.beam_search.beam_search import BeamSearch
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-        self.beam_search = beam_search
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[
-        Tuple[
-            Optional[str],
-            List[str],
-            List[int],
-            Union[Hypothesis],
-        ]
-    ]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        enc, _ = self.asr_model.encode(**batch)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-        assert len(enc) == 1, len(enc)
-
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-        )
-
-        nbest_hyps = nbest_hyps[: self.nbest]
-
-        results = []
-        for hyp in nbest_hyps:
-            assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-            # remove sos/eos and get results
-            last_pos = -1
-            if isinstance(hyp.yseq, list):
-                token_int = hyp.yseq[1:last_pos]
-            else:
-                token_int = hyp.yseq[1:last_pos].tolist()
-
-            # remove blank symbol id, which is assumed to be 0
-            token_int = list(filter(lambda x: x != 0, token_int))
-
-            # Change integer-ids to tokens
-            token = self.converter.ids2tokens(token_int)
-
-            if self.tokenizer is not None:
-                text = self.tokenizer.tokens2text(token)
-            else:
-                text = None
-            results.append((text, token, token_int, hyp))
-
-        return results
-
-
-class Speech2TextParaformer:
-    """Speech2Text class
-
-    Examples:
-            >>> import librosa
-            >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
-            >>> audio, rate = librosa.load("speech.wav")
-            >>> speech2text(audio)
-            [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            frontend_conf: dict = None,
-            hotword_list_or_file: str = None,
-            clas_scale: float = 1.0,
-            decoding_ind: int = 0,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        if asr_model.ctc != None:
-            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-            scorers.update(
-                ctc=ctc
-            )
-        token_list = asr_model.token_list
-        scorers.update(
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device, task_name="lm"
-            )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-        from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-        self.cmvn_file = cmvn_file
-
-        # 6. [Optional] Build hotword list from str, local file or url
-        self.hotword_list = None
-        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
-        self.clas_scale = clas_scale
-
-        is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
-            beam_search = None
-        self.beam_search = beam_search
-        logging.info(f"Beam_search: {self.beam_search}")
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-        self.encoder_downsampling_factor = 1
-        self.decoding_ind = decoding_ind
-        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
-            self.encoder_downsampling_factor = 4
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
-    ):
-        """Inference
-
-        Args:
-                speech: Input speech data
-        Returns:
-                text, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        if decoding_ind is None:
-            decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
-        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-        # assert len(enc) == 1, len(enc)
-        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
-        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
-                                                                        predictor_outs[2], predictor_outs[3]
-        pre_token_length = pre_token_length.round().long()
-        if torch.max(pre_token_length) < 1:
-            return []
-        if not isinstance(self.asr_model, ContextualParaformer) and \
-            not isinstance(self.asr_model, NeatContextualParaformer):
-            if self.hotword_list:
-                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length)
-            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-        else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, 
-                                                                     enc_len, 
-                                                                     pre_acoustic_embeds,
-                                                                     pre_token_length, 
-                                                                     hw_list=self.hotword_list,
-                                                                     clas_scale=self.clas_scale)
-            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
-        if isinstance(self.asr_model, BiCifParaformer):
-            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
-                                                                                pre_token_length)  # test no bias cif2
-
-        results = []
-        b, n, d = decoder_out.size()
-        for i in range(b):
-            x = enc[i, :enc_len[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-                )
-
-                nbest_hyps = nbest_hyps[: self.nbest]
-            else:
-                if pre_token_length[i] == 0:
-                    yseq = torch.tensor(
-                        [self.asr_model.sos] + [self.asr_model.eos], device=pre_acoustic_embeds.device
-                    )
-                    score = torch.tensor(0.0, device=pre_acoustic_embeds.device)
-                else:
-                    yseq = am_scores.argmax(dim=-1)
-                    score = am_scores.max(dim=-1)[0]
-                    score = torch.sum(score, dim=-1)
-                    # pad with mask tokens to ensure compatibility with sos/eos tokens
-                    yseq = torch.tensor(
-                        [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
-                    )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-            for hyp in nbest_hyps:
-                assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
-                # Change integer-ids to tokens
-                token = self.converter.ids2tokens(token_int)
-
-                if self.tokenizer is not None:
-                    text = self.tokenizer.tokens2text(token)
-                else:
-                    text = None
-                timestamp = []
-                if isinstance(self.asr_model, BiCifParaformer):
-                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
-                                                               us_peaks[i][:enc_len[i] * 3],
-                                                               copy.copy(token),
-                                                               vad_offset=begin_time)
-                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
-
-        return results
-
-    def generate_hotwords_list(self, hotword_list_or_file):
-        def load_seg_dict(seg_dict_file):
-            seg_dict = {}
-            assert isinstance(seg_dict_file, str)
-            with open(seg_dict_file, "r", encoding="utf8") as f:
-                lines = f.readlines()
-                for line in lines:
-                    s = line.strip().split()
-                    key = s[0]
-                    value = s[1:]
-                    seg_dict[key] = " ".join(value)
-            return seg_dict
-
-        def seg_tokenize(txt, seg_dict):
-            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
-            out_txt = ""
-            for word in txt:
-                word = word.lower()
-                if word in seg_dict:
-                    out_txt += seg_dict[word] + " "
-                else:
-                    if pattern.match(word):
-                        for char in word:
-                            if char in seg_dict:
-                                out_txt += seg_dict[char] + " "
-                            else:
-                                out_txt += "<unk>" + " "
-                    else:
-                        out_txt += "<unk>" + " "
-            return out_txt.strip().split()
-
-        seg_dict = None
-        if self.cmvn_file is not None:
-            model_dir = os.path.dirname(self.cmvn_file)
-            seg_dict_file = os.path.join(model_dir, 'seg_dict')
-            if os.path.exists(seg_dict_file):
-                seg_dict = load_seg_dict(seg_dict_file)
-            else:
-                seg_dict = None
-        # for None
-        if hotword_list_or_file is None:
-            hotword_list = None
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
-            logging.info("Attempting to parse hotwords from local txt...")
-            hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hw_list = hw.split()
-                    if seg_dict is not None:
-                        hw_list = seg_tokenize(hw_list, seg_dict)
-                    hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids(hw_list))
-                hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        elif hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hw_list = hw.split()
-                    if seg_dict is not None:
-                        hw_list = seg_tokenize(hw_list, seg_dict)
-                    hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids(hw_list))
-                hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for text str input
-        elif not hotword_list_or_file.endswith('.txt'):
-            logging.info("Attempting to parse hotwords as str...")
-            hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                hw_list = hw.strip().split()
-                if seg_dict is not None:
-                    hw_list = seg_tokenize(hw_list, seg_dict)
-                hotword_list.append(self.converter.tokens2ids(hw_list))
-            hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        else:
-            hotword_list = None
-        return hotword_list
-
-
-class Speech2TextParaformerOnline:
-    """Speech2Text class
-
-    Examples:
-            >>> import librosa
-            >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
-            >>> audio, rate = librosa.load("speech.wav")
-            >>> speech2text(audio)
-            [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            frontend_conf: dict = None,
-            hotword_list_or_file: str = None,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        if asr_model.ctc != None:
-            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-            scorers.update(
-                ctc=ctc
-            )
-        token_list = asr_model.token_list
-        scorers.update(
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device, task_name="lm"
-            )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-        from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-
-        # 6. [Optional] Build hotword list from str, local file or url
-
-        is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
-            beam_search = None
-        self.beam_search = beam_search
-        logging.info(f"Beam_search: {self.beam_search}")
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-        self.encoder_downsampling_factor = 1
-        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
-            self.encoder_downsampling_factor = 4
-
-    @torch.no_grad()
-    def __call__(
-            self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None
-    ):
-        """Inference
-
-        Args:
-                speech: Input speech data
-        Returns:
-                text, token, token_int, hyp
-
-        """
-        results = []
-        cache_en = cache["encoder"]
-        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
-            if cache_en["start_idx"] == 0:
-                return []
-            cache_en["tail_chunk"] = True
-            feats = cache_en["feats"]
-            feats_len = torch.tensor([feats.shape[1]])
-            self.asr_model.frontend = None
-            self.frontend.cache_reset()
-            results = self.infer(feats, feats_len, cache)
-            return results
-        else:
-            if self.frontend is not None:
-                if cache_en["start_idx"] == 0:
-                    self.frontend.cache_reset()
-                feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
-                feats = to_device(feats, device=self.device)
-                feats_len = feats_len.int()
-                self.asr_model.frontend = None
-            else:
-                feats = speech
-                feats_len = speech_lengths
-
-            if feats.shape[1] != 0:
-                results = self.infer(feats, feats_len, cache)
-
-        return results
-
-    @torch.no_grad()
-    def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None):
-        batch = {"speech": feats, "speech_lengths": feats_len}
-        batch = to_device(batch, device=self.device)
-        # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-        # assert len(enc) == 1, len(enc)
-        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
-        predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
-        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
-        if torch.max(pre_token_length) < 1:
-            return []
-        decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
-        decoder_out = decoder_outs
-
-        results = []
-        b, n, d = decoder_out.size()
-        for i in range(b):
-            x = enc[i, :enc_len[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-                )
-
-                nbest_hyps = nbest_hyps[: self.nbest]
-            else:
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
-            for hyp in nbest_hyps:
-                assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
-                # Change integer-ids to tokens
-                token = self.converter.ids2tokens(token_int)
-                postprocessed_result = ""
-                for item in token:
-                    if item.endswith('@@'):
-                        postprocessed_result += item[:-2]
-                    elif re.match('^[a-zA-Z]+$', item):
-                        postprocessed_result += item + " "
-                    else:
-                        postprocessed_result += item
-
-                results.append(postprocessed_result)
-
-        return results
-
-
-class Speech2TextUniASR:
-    """Speech2Text class
-
-    Examples:
-        >>> import librosa
-        >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2text(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            token_num_relax: int = 1,
-            decoding_ind: int = 0,
-            decoding_mode: str = "model1",
-            frontend_conf: dict = None,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-        if decoding_mode == "model1":
-            decoder = asr_model.decoder
-        else:
-            decoder = asr_model.decoder2
-
-        if asr_model.ctc != None:
-            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-            scorers.update(
-                ctc=ctc
-            )
-        token_list = asr_model.token_list
-        scorers.update(
-            decoder=decoder,
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, device, "lm"
-            )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-        from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        # logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-        self.beam_search = beam_search
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.token_num_relax = token_num_relax
-        self.decoding_ind = decoding_ind
-        self.decoding_mode = decoding_mode
-        self.frontend = frontend
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[
-        Tuple[
-            Optional[str],
-            List[str],
-            List[int],
-            Union[Hypothesis],
-        ]
-    ]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        feats_raw = feats.clone().to(self.device)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-        # b. Forward Encoder
-        _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-        assert len(enc) == 1, len(enc)
-        if self.decoding_mode == "model1":
-            predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
-        else:
-            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
-            predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
-
-        scama_mask = predictor_outs[4]
-        pre_token_length = predictor_outs[1]
-        pre_acoustic_embeds = predictor_outs[0]
-        maxlen = pre_token_length.sum().item() + self.token_num_relax
-        minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
-            minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
-        )
-
-        nbest_hyps = nbest_hyps[: self.nbest]
-
-        results = []
-        for hyp in nbest_hyps:
-            assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-            # remove sos/eos and get results
-            last_pos = -1
-            if isinstance(hyp.yseq, list):
-                token_int = hyp.yseq[1:last_pos]
-            else:
-                token_int = hyp.yseq[1:last_pos].tolist()
-
-            # remove blank symbol id, which is assumed to be 0
-            token_int = list(filter(lambda x: x != 0, token_int))
-
-            # Change integer-ids to tokens
-            token = self.converter.ids2tokens(token_int)
-            token = list(filter(lambda x: x != "<gbg>", token))
-
-            if self.tokenizer is not None:
-                text = self.tokenizer.tokens2text(token)
-            else:
-                text = None
-            results.append((text, token, token_int, hyp))
-
-        return results
-
-
-class Speech2TextMFCCA:
-    """Speech2Text class
-
-    Examples:
-        >>> import librosa
-        >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2text(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            batch_size: int = 1,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            streaming: bool = False,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        decoder = asr_model.decoder
-
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-        token_list = asr_model.token_list
-        scorers.update(
-            decoder=decoder,
-            ctc=ctc,
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device, task_name="lm"
-            )
-            lm.to(device)
-            scorers["lm"] = lm.lm
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-        # beam_search.__class__ = BatchBeamSearch
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-        self.beam_search = beam_search
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[
-        Tuple[
-            Optional[str],
-            List[str],
-            List[int],
-            Union[Hypothesis],
-        ]
-    ]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        if (speech.dim() == 3):
-            speech = torch.squeeze(speech, 2)
-        # speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        speech = speech.to(getattr(torch, self.dtype))
-        # lenghts: (1,)
-        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        batch = {"speech": speech, "speech_lengths": lengths}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        enc, _ = self.asr_model.encode(**batch)
-
-        assert len(enc) == 1, len(enc)
-
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-        )
-
-        nbest_hyps = nbest_hyps[: self.nbest]
-
-        results = []
-        for hyp in nbest_hyps:
-            assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-            # remove sos/eos and get results
-            last_pos = -1
-            if isinstance(hyp.yseq, list):
-                token_int = hyp.yseq[1:last_pos]
-            else:
-                token_int = hyp.yseq[1:last_pos].tolist()
-
-            # remove blank symbol id, which is assumed to be 0
-            token_int = list(filter(lambda x: x != 0, token_int))
-
-            # Change integer-ids to tokens
-            token = self.converter.ids2tokens(token_int)
-
-            if self.tokenizer is not None:
-                text = self.tokenizer.tokens2text(token)
-            else:
-                text = None
-            results.append((text, token, token_int, hyp))
-
-        return results
-
-
-class Speech2TextTransducer:
-    """Speech2Text class for Transducer models.
-    Args:
-        asr_train_config: ASR model training config path.
-        asr_model_file: ASR model path.
-        beam_search_config: Beam search config path.
-        lm_train_config: Language Model training config path.
-        lm_file: Language Model config path.
-        token_type: Type of token units.
-        bpemodel: BPE model path.
-        device: Device to use for inference.
-        beam_size: Size of beam during search.
-        dtype: Data type.
-        lm_weight: Language model weight.
-        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
-        quantize_modules: List of module names to apply dynamic quantization on.
-        quantize_dtype: Dynamic quantization data type.
-        nbest: Number of final hypothesis.
-        streaming: Whether to perform chunk-by-chunk inference.
-        chunk_size: Number of frames in chunk AFTER subsampling.
-        left_context: Number of frames in left context AFTER subsampling.
-        right_context: Number of frames in right context AFTER subsampling.
-        display_partial_hypotheses: Whether to display partial hypotheses.
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            beam_search_config: Dict[str, Any] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            beam_size: int = 5,
-            dtype: str = "float32",
-            lm_weight: float = 1.0,
-            quantize_asr_model: bool = False,
-            quantize_modules: List[str] = None,
-            quantize_dtype: str = "qint8",
-            nbest: int = 1,
-            streaming: bool = False,
-            fake_streaming: bool = False,
-            full_utt: bool = False,
-            chunk_size: int = 16,
-            left_context: int = 32,
-            right_context: int = 0,
-            display_partial_hypotheses: bool = False,
-    ) -> None:
-        """Construct a Speech2Text object."""
-        super().__init__()
-
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
-        if quantize_asr_model:
-            if quantize_modules is not None:
-                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
-                    raise ValueError(
-                        "Only 'Linear' and 'LSTM' modules are currently supported"
-                        " by PyTorch and in --quantize_modules"
-                    )
-
-                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
-            else:
-                q_config = {torch.nn.Linear}
-
-            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
-                raise ValueError(
-                    "float16 dtype for dynamic quantization is not supported with torch"
-                    " version < 1.5.0. Switching to qint8 dtype instead."
-                )
-            q_dtype = getattr(torch, quantize_dtype)
-
-            asr_model = torch.quantization.quantize_dynamic(
-                asr_model, q_config, dtype=q_dtype
-            ).eval()
-        else:
-            asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device, task_name="lm"
-            )
-            lm_scorer = lm.lm
-        else:
-            lm_scorer = None
-
-        # 4. Build BeamSearch object
-        if beam_search_config is None:
-            beam_search_config = {}
-
-        beam_search = BeamSearchTransducer(
-            asr_model.decoder,
-            asr_model.joint_network,
-            beam_size,
-            lm=lm_scorer,
-            lm_weight=lm_weight,
-            nbest=nbest,
-            **beam_search_config,
-        )
-
-        token_list = asr_model.token_list
-
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-
-        self.converter = converter
-        self.tokenizer = tokenizer
-
-        self.beam_search = beam_search
-        self.streaming = streaming
-        self.fake_streaming = fake_streaming
-        self.full_utt = full_utt
-        self.chunk_size = max(chunk_size, 0)
-        self.left_context = left_context
-        self.right_context = max(right_context, 0)
-
-        if not streaming or chunk_size == 0:
-            self.streaming = False
-            self.asr_model.encoder.dynamic_chunk_training = False
-
-        if not fake_streaming or chunk_size == 0:
-            self.fake_streaming = False
-            self.asr_model.encoder.dynamic_chunk_training = False
-
-        self.frontend = frontend
-        self.window_size = self.chunk_size + self.right_context
-
-        if self.streaming:
-            self._ctx = self.asr_model.encoder.get_encoder_input_size(
-                self.window_size
-            )
-            self._right_ctx = right_context
-
-            self.last_chunk_length = (
-                    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-            )
-            self.reset_inference_cache()
-
-    def reset_inference_cache(self) -> None:
-        """Reset Speech2Text parameters."""
-        self.frontend_cache = None
-
-        self.asr_model.encoder.reset_streaming_cache(
-            self.left_context, device=self.device
-        )
-        self.beam_search.reset_inference_cache()
-
-        self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
-    @torch.no_grad()
-    def streaming_decode(
-            self,
-            speech: Union[torch.Tensor, np.ndarray],
-            is_final: bool = True,
-    ) -> List[HypothesisTransducer]:
-        """Speech2Text streaming call.
-        Args:
-            speech: Chunk of speech data. (S)
-            is_final: Whether speech corresponds to the final chunk of data.
-        Returns:
-            nbest_hypothesis: N-best hypothesis.
-        """
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        if is_final:
-            if self.streaming and speech.size(0) < self.last_chunk_length:
-                pad = torch.zeros(
-                    self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
-                )
-                speech = torch.cat([speech, pad],
-                                   dim=0)  # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
-
-        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        feats = to_device(feats, device=self.device)
-        feats_lengths = to_device(feats_lengths, device=self.device)
-        enc_out = self.asr_model.encoder.chunk_forward(
-            feats,
-            feats_lengths,
-            self.num_processed_frames,
-            chunk_size=self.chunk_size,
-            left_context=self.left_context,
-            right_context=self.right_context,
-        )
-        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
-
-        self.num_processed_frames += self.chunk_size
-
-        if is_final:
-            self.reset_inference_cache()
-
-        return nbest_hyps
-
-    @torch.no_grad()
-    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
-        """Speech2Text call.
-        Args:
-            speech: Speech data. (S)
-        Returns:
-            nbest_hypothesis: N-best hypothesis.
-        """
-
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            speech = torch.unsqueeze(speech, axis=0)
-            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        feats = to_device(feats, device=self.device)
-        feats_lengths = to_device(feats_lengths, device=self.device)
-        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
-                                                            self.right_context)
-        nbest_hyps = self.beam_search(enc_out[0])
-
-        return nbest_hyps
-
-    @torch.no_grad()
-    def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
-        """Speech2Text call.
-        Args:
-            speech: Speech data. (S)
-        Returns:
-            nbest_hypothesis: N-best hypothesis.
-        """
-        assert check_argument_types()
-
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            speech = torch.unsqueeze(speech, axis=0)
-            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        if self.asr_model.normalize is not None:
-            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
-        feats = to_device(feats, device=self.device)
-        feats_lengths = to_device(feats_lengths, device=self.device)
-        enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
-        nbest_hyps = self.beam_search(enc_out[0])
-
-        return nbest_hyps
-
-    @torch.no_grad()
-    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
-        """Speech2Text call.
-        Args:
-            speech: Speech data. (S)
-        Returns:
-            nbest_hypothesis: N-best hypothesis.
-        """
-
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            speech = torch.unsqueeze(speech, axis=0)
-            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
-        feats = to_device(feats, device=self.device)
-        feats_lengths = to_device(feats_lengths, device=self.device)
-
-        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-        nbest_hyps = self.beam_search(enc_out[0])
-
-        return nbest_hyps
-
-    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
-        """Build partial or final results from the hypotheses.
-        Args:
-            nbest_hyps: N-best hypothesis.
-        Returns:
-            results: Results containing different representation for the hypothesis.
-        """
-        results = []
-
-        for hyp in nbest_hyps:
-            token_int = list(filter(lambda x: x != 0, hyp.yseq))
-
-            token = self.converter.ids2tokens(token_int)
-
-            if self.tokenizer is not None:
-                text = self.tokenizer.tokens2text(token)
-            else:
-                text = None
-            results.append((text, token, token_int, hyp))
-
-
-        return results
-
-
-class Speech2TextSAASR:
-    """Speech2Text class
-
-    Examples:
-        >>> import librosa
-        >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2text(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            batch_size: int = 1,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            streaming: bool = False,
-            frontend_conf: dict = None,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            from funasr.tasks.sa_asr import frontend_choices
-            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
-                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
-                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
-            else:
-                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
-                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        decoder = asr_model.decoder
-
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-        token_list = asr_model.token_list
-        scorers.update(
-            decoder=decoder,
-            ctc=ctc,
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        # 2. Build Language model
-        if lm_train_config is not None:
-            lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device, task_name="lm"
-            )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-
-        # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
-        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
-        )
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-        self.beam_search = beam_search
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
-            profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
-    ) -> List[
-        Tuple[
-            Optional[str],
-            Optional[str],
-            List[str],
-            List[int],
-            Union[HypothesisSAASR],
-        ]
-    ]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, text_id, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if isinstance(profile, np.ndarray):
-            profile = torch.tensor(profile)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
-        if isinstance(asr_enc, tuple):
-            asr_enc = asr_enc[0]
-        if isinstance(spk_enc, tuple):
-            spk_enc = spk_enc[0]
-        assert len(asr_enc) == 1, len(asr_enc)
-        assert len(spk_enc) == 1, len(spk_enc)
-
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-        )
-
-        nbest_hyps = nbest_hyps[: self.nbest]
-
-        results = []
-        for hyp in nbest_hyps:
-            assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
-
-            # remove sos/eos and get results
-            last_pos = -1
-            if isinstance(hyp.yseq, list):
-                token_int = hyp.yseq[1: last_pos]
-            else:
-                token_int = hyp.yseq[1: last_pos].tolist()
-
-            spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
-
-            token_ori = self.converter.ids2tokens(token_int)
-            text_ori = self.tokenizer.tokens2text(token_ori)
-
-            text_ori_spklist = text_ori.split('$')
-            cur_index = 0
-            spk_choose = []
-            for i in range(len(text_ori_spklist)):
-                text_ori_split = text_ori_spklist[i]
-                n = len(text_ori_split)
-                spk_weights_local = spk_weigths[cur_index: cur_index + n]
-                cur_index = cur_index + n + 1
-                spk_weights_local = spk_weights_local.mean(dim=0)
-                spk_choose_local = spk_weights_local.argmax(-1)
-                spk_choose.append(spk_choose_local.item() + 1)
-
-            # remove blank symbol id, which is assumed to be 0
-            token_int = list(filter(lambda x: x != 0, token_int))
-
-            # Change integer-ids to tokens
-            token = self.converter.ids2tokens(token_int)
-
-            if self.tokenizer is not None:
-                text = self.tokenizer.tokens2text(token)
-            else:
-                text = None
-
-            text_spklist = text.split('$')
-            assert len(spk_choose) == len(text_spklist)
-
-            spk_list = []
-            for i in range(len(text_spklist)):
-                text_split = text_spklist[i]
-                n = len(text_split)
-                spk_list.append(str(spk_choose[i]) * n)
-
-            text_id = '$'.join(spk_list)
-
-            assert len(text) == len(text_id)
-
-            results.append((text, text_id, token, token_int, hyp))
-
-        return results
-
-
-class Speech2TextWhisper:
-    """Speech2Text class
-
-    Examples:
-        >>> import librosa
-        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2text(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            batch_size: int = 1,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            streaming: bool = False,
-            frontend_conf: dict = None,
-            language: str = None,
-            task: str = "transcribe",
-            **kwargs,
-    ):
-
-        from funasr.tasks.whisper import ASRTask
-
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-        frontend = None
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        decoder = asr_model.decoder
-
-        token_list = []
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.tokenizer = tokenizer
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.language = language
-        self.task = task
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[
-        Tuple[
-            Optional[str],
-            List[str],
-            List[int],
-            Union[Hypothesis],
-        ]
-    ]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-
-        from funasr.utils.whisper_utils.transcribe import transcribe
-        from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
-        from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
-
-        speech = speech[0]
-        speech = pad_or_trim(speech)
-        mel = log_mel_spectrogram(speech).to(self.device)
-
-        if self.asr_model.is_multilingual:
-            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
-            asr_res = decode(self.asr_model, mel, options)
-            text = asr_res.text
-            language = self.language if self.language else asr_res.language
-        else:
-            asr_res = transcribe(self.asr_model, speech, fp16=False)
-            text = asr_res["text"]
-            language = asr_res["language"]
-        results = [(text, language)]
-        return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
deleted file mode 100644
index 6151d28..0000000
--- a/funasr/bin/asr_inference_launch.py
+++ /dev/null
@@ -1,2248 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-from optparse import Option
-import os
-import sys
-import time
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-import torchaudio
-# import librosa
-import librosa
-import yaml
-
-from funasr.bin.asr_infer import Speech2Text
-from funasr.bin.asr_infer import Speech2TextMFCCA
-from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
-from funasr.bin.asr_infer import Speech2TextSAASR
-from funasr.bin.asr_infer import Speech2TextTransducer
-from funasr.bin.asr_infer import Speech2TextUniASR
-from funasr.bin.asr_infer import Speech2TextWhisper
-from funasr.bin.punc_infer import Text2Punc
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.subsampling import TooShortUttError
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import asr_utils, postprocess_utils
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.utils.speaker_utils import (check_audio_list,
-                                        sv_preprocess,
-                                        sv_chunk,
-                                        extract_feature,
-                                        postprocess,
-                                        distribute_spk)
-import funasr.modules.cnn as sv_module
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.utils.cluster_backend import ClusterBackend
-from funasr.utils.modelscope_utils import get_cache_dir
-from tqdm import tqdm
-
-def inference_asr(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        mc: bool = False,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    for handler in logging.root.handlers[:]:
-        logging.root.removeHandler(handler)
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        streaming=streaming,
-    )
-    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
-    speech2text = Speech2Text(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            mc=mc,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp]] * nbest
-
-            # Only supporting batch_size==1
-            key = keys[0]
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    # Write the result to each file
-                    ibest_writer["token"][key] = " ".join(token)
-                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-
-                if text is not None:
-                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                    item = {'key': key, 'value': text_postprocessed}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    asr_utils.print_progress(finish_count / file_count)
-                    if writer is not None:
-                        ibest_writer["text"][key] = text
-
-                logging.info("uttid: {}".format(key))
-                logging.info("text predictions: {}\n".format(text))
-        return asr_result_list
-
-    return _forward
-
-
-def inference_paraformer(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        output_dir: Optional[str] = None,
-        timestamp_infer_config: Union[Path, str] = None,
-        timestamp_model_file: Union[Path, str] = None,
-        param_dict: dict = None,
-        decoding_ind: int = 0,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    export_mode = False
-    if param_dict is not None:
-        hotword_list_or_file = param_dict.get('hotword')
-        export_mode = param_dict.get("export_mode", False)
-        clas_scale = param_dict.get('clas_scale', 1.0)
-    else:
-        hotword_list_or_file = None
-        clas_scale = 1.0
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-        batch_size = 1
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        hotword_list_or_file=hotword_list_or_file,
-        clas_scale=clas_scale,
-        decoding_ind=decoding_ind,
-    )
-
-    speech2text = Speech2TextParaformer(**speech2text_kwargs)
-
-    if timestamp_model_file is not None:
-        speechtext2timestamp = Speech2Timestamp(
-            timestamp_cmvn_file=cmvn_file,
-            timestamp_model_file=timestamp_model_file,
-            timestamp_infer_config=timestamp_infer_config,
-        )
-    else:
-        speechtext2timestamp = None
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None,
-            **kwargs,
-    ):
-
-        decoding_ind = None
-        hotword_list_or_file = None
-        if param_dict is not None:
-            hotword_list_or_file = param_dict.get('hotword')
-        if 'hotword' in kwargs and kwargs['hotword'] is not None:
-            hotword_list_or_file = kwargs['hotword']
-        if hotword_list_or_file is not None or 'hotword' in kwargs:
-            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-        if param_dict is not None and "decoding_ind" in param_dict:
-            decoding_ind = param_dict["decoding_ind"]
-
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        if param_dict is not None:
-            use_timestamp = param_dict.get('use_timestamp', True)
-        else:
-            use_timestamp = True
-
-        forward_time_total = 0.0
-        length_total = 0.0
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
-            logging.info("decoding, utt_id: {}".format(keys))
-            # N-best list of (text, token, token_int, hyp_object)
-
-            time_beg = time.time()
-            batch["decoding_ind"] = decoding_ind
-            results = speech2text(**batch)
-            if len(results) < 1:
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp, 10, 6, []]] * nbest
-            time_end = time.time()
-            forward_time = time_end - time_beg
-            lfr_factor = results[0][-1]
-            length = results[0][-2]
-            forward_time_total += forward_time
-            length_total += length
-            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
-                                                                                               100 * forward_time / (
-                                                                                                       length * lfr_factor))
-            logging.info(rtf_cur)
-
-            for batch_id in range(_bs):
-                result = [results[batch_id][:-2]]
-
-                key = keys[batch_id]
-                for n, result in zip(range(1, nbest + 1), result):
-                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
-                    timestamp = result[4] if len(result[4]) > 0 else None
-                    # conduct timestamp prediction here
-                    # timestamp inference requires token length
-                    # thus following inference cannot be conducted in batch
-                    if timestamp is None and speechtext2timestamp:
-                        ts_batch = {}
-                        ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
-                        ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
-                        ts_batch['text_lengths'] = torch.tensor([len(token)])
-                        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
-                        ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
-                                                                        force_time_shift=-3.0)
-                    # Create a directory: outdir/{n}best_recog
-                    if writer is not None:
-                        ibest_writer = writer[f"{n}best_recog"]
-
-                        # Write the result to each file
-                        ibest_writer["token"][key] = " ".join(token)
-                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                        ibest_writer["score"][key] = str(hyp.score)
-                        ibest_writer["rtf"][key] = rtf_cur
-
-                    if text is not None:
-                        if use_timestamp and timestamp is not None and len(timestamp):
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
-                        else:
-                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
-                        timestamp_postprocessed = ""
-                        if len(postprocessed_result) == 3:
-                            text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
-                                                                                      postprocessed_result[1], \
-                                                                                      postprocessed_result[2]
-                        else:
-                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-                        item = {'key': key, 'value': text_postprocessed}
-                        if timestamp_postprocessed != "":
-                            item['timestamp'] = timestamp_postprocessed
-                        asr_result_list.append(item)
-                        finish_count += 1
-                        # asr_utils.print_progress(finish_count / file_count)
-                        if writer is not None:
-                            ibest_writer["text"][key] = " ".join(word_lists)
-
-                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
-        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
-                                                                                                           forward_time_total,
-                                                                                                           100 * forward_time_total / (
-                                                                                                                   length_total * lfr_factor))
-        logging.info(rtf_avg)
-        if writer is not None:
-            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
-        torch.cuda.empty_cache()
-        return asr_result_list
-
-    return _forward
-
-
-def inference_paraformer_vad_punc(
-        maxlenratio: float=0.0,
-        minlenratio: float=0.0,
-        batch_size: int=1,
-        beam_size: int=1,
-        ngpu: int=1,
-        ctc_weight: float=0.0,
-        lm_weight: float=0.0,
-        penalty: float=0.0,
-        log_level: Union[int, str]=logging.ERROR,
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str]=None,
-        asr_model_file: Optional[str]=None,
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 0,
-        vad_infer_config: Optional[str] = None,
-        vad_model_file: Optional[str] = None,
-        vad_cmvn_file: Optional[str] = None,
-        time_stamp_writer: bool = True,
-        punc_infer_config: Optional[str] = None,
-        punc_model_file: Optional[str] = None,
-        outputs_dict: Optional[bool] = True,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    language = kwargs.get("model_lang", None)
-
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if param_dict is not None:
-        hotword_list_or_file = param_dict.get('hotword')
-    else:
-        hotword_list_or_file = None
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2vadsegment
-    speech2vadsegment_kwargs = dict(
-        vad_infer_config=vad_infer_config,
-        vad_model_file=vad_model_file,
-        vad_cmvn_file=vad_cmvn_file,
-        device=device,
-        dtype=dtype,
-    )
-    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
-    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
-    # 3. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        hotword_list_or_file=hotword_list_or_file,
-    )
-    speech2text = Speech2TextParaformer(**speech2text_kwargs)
-    text2punc = None
-    if punc_model_file is not None:
-        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
-    if output_dir is not None:
-        writer = DatadirWriter(output_dir)
-        ibest_writer = writer[f"1best_recog"]
-        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-
-        hotword_list_or_file = None
-        if param_dict is not None:
-            hotword_list_or_file = param_dict.get('hotword')
-
-        if 'hotword' in kwargs:
-            hotword_list_or_file = kwargs['hotword']
-
-        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
-        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
-        batch_size_token = kwargs.get("batch_size_token", 6000)
-        print("batch_size_token: ", batch_size_token)
-
-        if speech2text.hotword_list is None:
-            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=1,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        if param_dict is not None:
-            use_timestamp = param_dict.get('use_timestamp', True)
-        else:
-            use_timestamp = True
-
-        finish_count = 0
-        file_count = 1
-        lfr_factor = 6
-        # 7 .Start for-loop
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        writer = None
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-            ibest_writer = writer[f"1best_recog"]
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            beg_vad = time.time()
-            vad_results = speech2vadsegment(**batch)
-            end_vad = time.time()
-            print("time cost vad: ", end_vad - beg_vad)
-            _, vadsegments = vad_results[0], vad_results[1][0]
-
-            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
-            n = len(vadsegments)
-            data_with_index = [(vadsegments[i], i) for i in range(n)]
-            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
-            results_sorted = []
-            
-            if not len(sorted_data):
-                key = keys[0]
-                # no active segments after VAD
-                if writer is not None:
-                    # Write empty results
-                    ibest_writer["token"][key] = ""
-                    ibest_writer["token_int"][key] = ""
-                    ibest_writer["vad"][key] = ""
-                    ibest_writer["text"][key] = ""
-                    ibest_writer["text_with_punc"][key] = ""
-                    if use_timestamp:
-                        ibest_writer["time_stamp"][key] = ""
-
-                logging.info("decoding, utt: {}, empty speech".format(key))
-                continue
-
-            batch_size_token_ms = batch_size_token*60
-            if speech2text.device == "cpu":
-                batch_size_token_ms = 0
-            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
-                batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
-            
-            batch_size_token_ms_cum = 0
-            beg_idx = 0
-            beg_asr_total = time.time()
-            for j, _ in enumerate(tqdm(range(0, n))):
-                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
-                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
-                    continue
-                batch_size_token_ms_cum = 0
-                end_idx = j + 1
-                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-                beg_idx = end_idx
-                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
-                batch = to_device(batch, device=device)
-
-                beg_asr = time.time()
-                results = speech2text(**batch)
-                end_asr = time.time()
-                if speech2text.device != "cpu":
-                    print("batch: ", speech_j.shape[0])
-                    print("time cost asr: ", end_asr - beg_asr)
-
-                if len(results) < 1:
-                    results = [["", [], [], [], [], [], []]]
-                results_sorted.extend(results)
-            end_asr_total = time.time()
-            print("total time cost asr: ", end_asr_total-beg_asr_total)
-            restored_data = [0] * n
-            for j in range(n):
-                index = sorted_data[j][1]
-                restored_data[index] = results_sorted[j]
-            result = ["", [], [], [], [], [], []]
-            for j in range(n):
-                result[0] += restored_data[j][0]
-                result[1] += restored_data[j][1]
-                result[2] += restored_data[j][2]
-                if len(restored_data[j][4]) > 0:
-                    for t in restored_data[j][4]:
-                        t[0] += vadsegments[j][0]
-                        t[1] += vadsegments[j][0]
-                    result[4] += restored_data[j][4]
-                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
-            key = keys[0]
-            # result = result_segments[0]
-            text, token, token_int = result[0], result[1], result[2]
-            time_stamp = result[4] if len(result[4]) > 0 else None
-
-            if language == "en-bpe":
-                postprocessed_result = postprocess_utils.sentence_postprocess_sentencepiece(token)
-            else:
-                if use_timestamp and time_stamp is not None and len(time_stamp):
-                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
-                else:
-                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
-            text_postprocessed = ""
-            time_stamp_postprocessed = ""
-            text_postprocessed_punc = postprocessed_result
-            if len(postprocessed_result) == 3:
-                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
-                                                                           postprocessed_result[1], \
-                                                                           postprocessed_result[2]
-            else:
-                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
-            text_postprocessed_punc = text_postprocessed
-            punc_id_list = []
-            if len(word_lists) > 0 and text2punc is not None:
-                beg_punc = time.time()
-                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
-                end_punc = time.time()
-                print("time cost punc: ", end_punc - beg_punc)
-
-            item = {'key': key, 'value': text_postprocessed_punc}
-            if text_postprocessed != "":
-                item['text_postprocessed'] = text_postprocessed
-            if time_stamp_postprocessed != "":
-                item['time_stamp'] = time_stamp_postprocessed
-
-            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
-            asr_result_list.append(item)
-            finish_count += 1
-            # asr_utils.print_progress(finish_count / file_count)
-            if writer is not None:
-                # Write the result to each file
-                ibest_writer["token"][key] = " ".join(token)
-                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                ibest_writer["vad"][key] = "{}".format(vadsegments)
-                ibest_writer["text"][key] = " ".join(word_lists)
-                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
-                if time_stamp_postprocessed is not None:
-                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
-            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
-        torch.cuda.empty_cache()
-        return asr_result_list
-
-    return _forward
-
-
-def inference_paraformer_vad_speaker(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        vad_infer_config: Optional[str] = None,
-        vad_model_file: Optional[str] = None,
-        vad_cmvn_file: Optional[str] = None,
-        time_stamp_writer: bool = True,
-        punc_infer_config: Optional[str] = None,
-        punc_model_file: Optional[str] = None,
-        sv_model_file: Optional[str] = None, 
-        streaming: bool = False,
-        embedding_node: str = "resnet1_dense",
-        sv_threshold: float = 0.9465,
-        outputs_dict: Optional[bool] = True,
-        param_dict: dict = None,
-
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml")
-    if not os.path.exists(sv_model_config_path):
-        sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}}
-    else:
-        with open(sv_model_config_path, 'r') as f:
-            sv_model_config = yaml.load(f, Loader=yaml.FullLoader)
-    if sv_model_config['models_config'] is None:
-        sv_model_config['models_config'] = {}
-    sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file'])
-
-    if param_dict is not None:
-        hotword_list_or_file = param_dict.get('hotword')
-    else:
-        hotword_list_or_file = None
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2vadsegment
-    speech2vadsegment_kwargs = dict(
-        vad_infer_config=vad_infer_config,
-        vad_model_file=vad_model_file,
-        vad_cmvn_file=vad_cmvn_file,
-        device=device,
-        dtype=dtype,
-    )
-    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
-    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
-    # 3. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        hotword_list_or_file=hotword_list_or_file,
-    )
-    speech2text = Speech2TextParaformer(**speech2text_kwargs)
-    text2punc = None
-    if punc_model_file is not None:
-        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
-    if output_dir is not None:
-        writer = DatadirWriter(output_dir)
-        ibest_writer = writer[f"1best_recog"]
-        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-
-        hotword_list_or_file = None
-        if param_dict is not None:
-            hotword_list_or_file = param_dict.get('hotword')
-
-        if 'hotword' in kwargs:
-            hotword_list_or_file = kwargs['hotword']
-
-        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
-        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
-        batch_size_token = kwargs.get("batch_size_token", 6000)
-        print("batch_size_token: ", batch_size_token)
-
-        if speech2text.hotword_list is None:
-            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=1,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        if param_dict is not None:
-            use_timestamp = param_dict.get('use_timestamp', True)
-        else:
-            use_timestamp = True
-
-        finish_count = 0
-        file_count = 1
-        lfr_factor = 6
-        # 7 .Start for-loop
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        writer = None
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-            ibest_writer = writer[f"1best_recog"]
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            beg_vad = time.time()
-            vad_results = speech2vadsegment(**batch)
-            end_vad = time.time()
-            print("time cost vad: ", end_vad - beg_vad)
-            _, vadsegments = vad_results[0], vad_results[1][0]
-            ##################################
-            #####  speaker_verification  #####
-            ##################################
-            # load sv model
-            if ngpu > 0:
-                sv_model_dict = torch.load(sv_model_file)
-                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
-                sv_model.cuda()
-            else:
-                sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
-                sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
-            sv_model.load_state_dict(sv_model_dict)
-            print(f'load sv model params: {sv_model_file}')
-            sv_model.eval()
-            cb_model = ClusterBackend()
-            vad_segments = []
-            audio = batch['speech'].numpy().reshape(-1)
-            for vadsegment in vadsegments:
-                st = int(vadsegment[0]) / 1000
-                ed = int(vadsegment[1]) / 1000
-                vad_segments.append(
-                    [st, ed, audio[int(st * 16000):int(ed * 16000)]])
-            audio_dur = check_audio_list(vad_segments)
-            if audio_dur > 5:
-                # sv pipeline
-                segments = sv_chunk(vad_segments)
-                embeddings = []
-                for s in segments:
-                    #_, embs = self.sv_pipeline([s[2]], output_emb=True)
-                    # embeddings.append(embs)
-                    wavs = sv_preprocess([s[2]])
-                    # embs = self.forward(wavs)
-                    embs = []
-                    for x in wavs:
-                        x = extract_feature([x])
-                        if ngpu > 0:
-                            x = x.cuda()
-                        embs.append(sv_model(x))
-                    embs = torch.cat(embs)
-                    embeddings.append(embs.cpu().detach().numpy())
-                embeddings = np.concatenate(embeddings)
-                labels = cb_model(embeddings)
-                sv_output = postprocess(segments, vad_segments, labels, embeddings)
-            else:
-                # fake speaker res for too shot utterance
-                sv_output = [[0.0, vadsegments[-1][-1]/1000.0, 0]]
-                logging.warning("Too short utterence found: {}, return default speaker results.".format(keys))
-
-            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
-            n = len(vadsegments)
-            data_with_index = [(vadsegments[i], i) for i in range(n)]
-            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
-            results_sorted = []
-            
-            if not len(sorted_data):
-                key = keys[0]
-                # no active segments after VAD
-                if writer is not None:
-                    # Write empty results
-                    ibest_writer["token"][key] = ""
-                    ibest_writer["token_int"][key] = ""
-                    ibest_writer["vad"][key] = ""
-                    ibest_writer["text"][key] = ""
-                    ibest_writer["text_with_punc"][key] = ""
-                    if use_timestamp:
-                        ibest_writer["time_stamp"][key] = ""
-
-                logging.info("decoding, utt: {}, empty speech".format(key))
-                continue
-
-            batch_size_token_ms = batch_size_token*60
-            if speech2text.device == "cpu":
-                batch_size_token_ms = 0
-            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
-                batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
-            
-            batch_size_token_ms_cum = 0
-            beg_idx = 0
-            beg_asr_total = time.time()
-            for j, _ in enumerate(tqdm(range(0, n))):
-                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
-                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
-                    continue
-                batch_size_token_ms_cum = 0
-                end_idx = j + 1
-                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-                beg_idx = end_idx
-                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
-                batch = to_device(batch, device=device)
-                # print("batch: ", speech_j.shape[0])
-                beg_asr = time.time()
-                results = speech2text(**batch)
-                end_asr = time.time()
-                # print("time cost asr: ", end_asr - beg_asr)
-
-                if len(results) < 1:
-                    results = [["", [], [], [], [], [], []]]
-                results_sorted.extend(results)
-            end_asr_total = time.time()
-            print("total time cost asr: ", end_asr_total-beg_asr_total)
-            restored_data = [0] * n
-            for j in range(n):
-                index = sorted_data[j][1]
-                restored_data[index] = results_sorted[j]
-            result = ["", [], [], [], [], [], []]
-            for j in range(n):
-                result[0] += restored_data[j][0]
-                result[1] += restored_data[j][1]
-                result[2] += restored_data[j][2]
-                if len(restored_data[j][4]) > 0:
-                    for t in restored_data[j][4]:
-                        t[0] += vadsegments[j][0]
-                        t[1] += vadsegments[j][0]
-                    result[4] += restored_data[j][4]
-                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
-            key = keys[0]
-            # result = result_segments[0]
-            text, token, token_int = result[0], result[1], result[2]
-            time_stamp = result[4] if len(result[4]) > 0 else None
-
-            if use_timestamp and time_stamp is not None and len(time_stamp):
-                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
-            else:
-                postprocessed_result = postprocess_utils.sentence_postprocess(token)
-            text_postprocessed = ""
-            time_stamp_postprocessed = ""
-            text_postprocessed_punc = postprocessed_result
-            if len(postprocessed_result) == 3:
-                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
-                                                                           postprocessed_result[1], \
-                                                                           postprocessed_result[2]
-            else:
-                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
-            text_postprocessed_punc = text_postprocessed
-            punc_id_list = []
-            if len(word_lists) > 0 and text2punc is not None:
-                beg_punc = time.time()
-                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
-                end_punc = time.time()
-                print("time cost punc: ", end_punc - beg_punc)
-
-            item = {'key': key, 'value': text_postprocessed_punc}
-            if text_postprocessed != "":
-                item['text_postprocessed'] = text_postprocessed
-            if time_stamp_postprocessed != "":
-                item['time_stamp'] = time_stamp_postprocessed
-
-            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
-            asr_result_list.append(item)
-            finish_count += 1
-            # asr_utils.print_progress(finish_count / file_count)
-            if writer is not None:
-                # Write the result to each file
-                ibest_writer["token"][key] = " ".join(token)
-                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                ibest_writer["vad"][key] = "{}".format(vadsegments)
-                ibest_writer["text"][key] = " ".join(word_lists)
-                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
-                if time_stamp_postprocessed is not None:
-                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
-            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
-        torch.cuda.empty_cache()
-        distribute_spk(asr_result_list[0]['sentences'], sv_output)
-        return asr_result_list
-
-    return _forward
-
-
-def inference_paraformer_online(
-        maxlenratio: float=0.0,
-        minlenratio: float=0.0,
-        batch_size: int=1,
-        beam_size: int=1,
-        ngpu: int=1,
-        ctc_weight: float=0.0,
-        lm_weight: float=0.0,
-        penalty: float=0.0,
-        log_level: Union[int, str]=logging.ERROR,
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str]=None,
-        asr_model_file: Optional[str]=None,
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        output_dir: Optional[str] = None,
-        param_dict: dict = None,
-        **kwargs,
-):
-
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    export_mode = False
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-        batch_size = 1
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-    )
-
-    speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
-
-    def _load_bytes(input):
-        middle_data = np.frombuffer(input, dtype=np.int16)
-        middle_data = np.asarray(middle_data)
-        if middle_data.dtype.kind not in 'iu':
-            raise TypeError("'middle_data' must be an array of integers")
-        dtype = np.dtype('float32')
-        if dtype.kind != 'f':
-            raise TypeError("'dtype' must be a floating point type")
-
-        i = np.iinfo(middle_data.dtype)
-        abs_max = 2 ** (i.bits - 1)
-        offset = i.min + abs_max
-        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
-        return array
-
-    def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
-        if not Path(yaml_path).exists():
-            raise FileExistsError(f'The {yaml_path} does not exist.')
-
-        with open(str(yaml_path), 'rb') as f:
-            data = yaml.load(f, Loader=yaml.Loader)
-        return data
-
-    def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], encoder_chunk_look_back=0,
-                       decoder_chunk_look_back=0, batch_size=1):
-        if len(cache) > 0:
-            return cache
-        config = _read_yaml(asr_train_config)
-        enc_output_size = config["encoder_conf"]["output_size"]
-        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
-        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                    "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
-                    "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
-                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
-        cache["encoder"] = cache_en
-
-        cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
-        cache["decoder"] = cache_de
-
-        return cache
-
-    def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], encoder_chunk_look_back=0,
-                     decoder_chunk_look_back=0, batch_size=1):
-        if len(cache) > 0:
-            config = _read_yaml(asr_train_config)
-            enc_output_size = config["encoder_conf"]["output_size"]
-            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
-            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
-                        "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
-            cache["encoder"] = cache_en
-
-            cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
-            cache["decoder"] = cache_de
-
-        return cache
-
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None,
-            **kwargs,
-    ):
-
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
-            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
-            raw_inputs = torch.tensor(raw_inputs)
-        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            try:
-                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
-            except:
-                # raw_inputs = librosa.load(data_path_and_name_and_type[0], dtype='float32')[0]
-                raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], dtype='float32')
-                if raw_inputs.ndim == 2:
-                    raw_inputs = raw_inputs[:, 0]
-                raw_inputs = torch.tensor(raw_inputs)
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, np.ndarray):
-                raw_inputs = torch.tensor(raw_inputs)
-        is_final = False
-        cache = {}
-        chunk_size = [5, 10, 5]
-        encoder_chunk_look_back = 0
-        decoder_chunk_look_back = 0
-        if param_dict is not None and "cache" in param_dict:
-            cache = param_dict["cache"]
-        if param_dict is not None and "is_final" in param_dict:
-            is_final = param_dict["is_final"]
-        if param_dict is not None and "chunk_size" in param_dict:
-            chunk_size = param_dict["chunk_size"]
-        if param_dict is not None and "encoder_chunk_look_back" in param_dict:
-            encoder_chunk_look_back = param_dict["encoder_chunk_look_back"]
-            if encoder_chunk_look_back > 0:
-                chunk_size[0] = 0
-        if param_dict is not None and "decoder_chunk_look_back" in param_dict:
-            decoder_chunk_look_back = param_dict["decoder_chunk_look_back"]
-
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
-        asr_result_list = []
-        cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, 
-                               decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
-        item = {}
-        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            sample_offset = 0
-            speech_length = raw_inputs.shape[1]
-            stride_size = chunk_size[1] * 960
-            cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, 
-                                   decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
-            final_result = ""
-            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
-                if sample_offset + stride_size >= speech_length - 1:
-                    stride_size = speech_length - sample_offset
-                    cache["encoder"]["is_final"] = True
-                else:
-                    cache["encoder"]["is_final"] = False
-                input_lens = torch.tensor([stride_size])
-                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
-                if len(asr_result) != 0:
-                    final_result += " ".join(asr_result) + " "
-            item = {'key': "utt", 'value': final_result.strip()}
-        else:
-            input_lens = torch.tensor([raw_inputs.shape[1]])
-            cache["encoder"]["is_final"] = is_final
-            asr_result = speech2text(cache, raw_inputs, input_lens)
-            item = {'key': "utt", 'value': " ".join(asr_result)}
-
-        asr_result_list.append(item)
-        if is_final:
-            cache = _cache_reset(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, 
-                                 decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
-        return asr_result_list
-
-    return _forward
-
-
-def inference_uniasr(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        ngram_file: Optional[str] = None,
-        cmvn_file: Optional[str] = None,
-        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        token_num_relax: int = 1,
-        decoding_ind: int = 0,
-        decoding_mode: str = "model1",
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    if param_dict is not None and "decoding_model" in param_dict:
-        if param_dict["decoding_model"] == "fast":
-            decoding_ind = 0
-            decoding_mode = "model1"
-        elif param_dict["decoding_model"] == "normal":
-            decoding_ind = 0
-            decoding_mode = "model2"
-        elif param_dict["decoding_model"] == "offline":
-            decoding_ind = 1
-            decoding_mode = "model2"
-        else:
-            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        ngram_file=ngram_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        streaming=streaming,
-        token_num_relax=token_num_relax,
-        decoding_ind=decoding_ind,
-        decoding_mode=decoding_mode,
-    )
-    speech2text = Speech2TextUniASR(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp]] * nbest
-
-            # Only supporting batch_size==1
-            key = keys[0]
-            logging.info(f"Utterance: {key}")
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    # Write the result to each file
-                    ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-
-                if text is not None:
-                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
-                    item = {'key': key, 'value': text_postprocessed}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    asr_utils.print_progress(finish_count / file_count)
-                    if writer is not None:
-                        ibest_writer["text"][key] = " ".join(word_lists)
-        return asr_result_list
-
-    return _forward
-
-
-def inference_mfcca(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        streaming=streaming,
-    )
-    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
-    speech2text = Speech2TextMFCCA(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            fs=fs,
-            mc=True,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["<space>"], [2], hyp]] * nbest
-
-            # Only supporting batch_size==1
-            key = keys[0]
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    # Write the result to each file
-                    ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-
-                if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
-                    item = {'key': key, 'value': text_postprocessed}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    asr_utils.print_progress(finish_count / file_count)
-                    if writer is not None:
-                        ibest_writer["text"][key] = text
-        return asr_result_list
-
-    return _forward
-
-
-def inference_transducer(
-        output_dir: str,
-        batch_size: int,
-        dtype: str,
-        beam_size: int,
-        ngpu: int,
-        seed: int,
-        lm_weight: float,
-        nbest: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        beam_search_config: Optional[dict] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        model_tag: Optional[str] = None,
-        token_type: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        key_file: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        quantize_asr_model: Optional[bool] = False,
-        quantize_modules: Optional[List[str]] = None,
-        quantize_dtype: Optional[str] = "float16",
-        streaming: Optional[bool] = False,
-        fake_streaming: Optional[bool] = False,
-        full_utt: Optional[bool] = False,
-        chunk_size: Optional[int] = 16,
-        left_context: Optional[int] = 16,
-        right_context: Optional[int] = 0,
-        display_partial_hypotheses: bool = False,
-        **kwargs,
-) -> None:
-    """Transducer model inference.
-    Args:
-        output_dir: Output directory path.
-        batch_size: Batch decoding size.
-        dtype: Data type.
-        beam_size: Beam size.
-        ngpu: Number of GPUs.
-        seed: Random number generator seed.
-        lm_weight: Weight of language model.
-        nbest: Number of final hypothesis.
-        num_workers: Number of workers.
-        log_level: Level of verbose for logs.
-        data_path_and_name_and_type:
-        asr_train_config: ASR model training config path.
-        asr_model_file: ASR model path.
-        beam_search_config: Beam search config path.
-        lm_train_config: Language Model training config path.
-        lm_file: Language Model path.
-        model_tag: Model tag.
-        token_type: Type of token units.
-        bpemodel: BPE model path.
-        key_file: File key.
-        allow_variable_data_keys: Whether to allow variable data keys.
-        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
-        quantize_modules: List of module names to apply dynamic quantization on.
-        quantize_dtype: Dynamic quantization data type.
-        streaming: Whether to perform chunk-by-chunk inference.
-        chunk_size: Number of frames in chunk AFTER subsampling.
-        left_context: Number of frames in left context AFTER subsampling.
-        right_context: Number of frames in right context AFTER subsampling.
-        display_partial_hypotheses: Whether to display partial hypotheses.
-    """
-
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        beam_search_config=beam_search_config,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        dtype=dtype,
-        beam_size=beam_size,
-        lm_weight=lm_weight,
-        nbest=nbest,
-        quantize_asr_model=quantize_asr_model,
-        quantize_modules=quantize_modules,
-        quantize_dtype=quantize_dtype,
-        streaming=streaming,
-        fake_streaming=fake_streaming,
-        full_utt=full_utt,
-        chunk_size=chunk_size,
-        left_context=left_context,
-        right_context=right_context,
-    )
-    speech2text = Speech2TextTransducer(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-        asr_result_list = []
-
-        if output_dir is not None:
-            writer = DatadirWriter(output_dir)
-        else:
-            writer = None
-
-        # 4 .Start for-loop
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-            assert len(batch.keys()) == 1
-
-            try:
-                if speech2text.streaming:
-                    speech = batch["speech"]
-
-                    _steps = len(speech) // speech2text._ctx
-                    _end = 0
-                    for i in range(_steps):
-                        _end = (i + 1) * speech2text._ctx
-
-                        speech2text.streaming_decode(
-                            speech[i * speech2text._ctx: _end + speech2text._right_ctx], is_final=False
-                        )
-
-                    final_hyps = speech2text.streaming_decode(
-                        speech[_end: len(speech)], is_final=True
-                    )
-                elif speech2text.fake_streaming:
-                    final_hyps = speech2text.fake_streaming_decode(**batch)
-                elif speech2text.full_utt:
-                    final_hyps = speech2text.full_utt_decode(**batch)
-                else:
-                    final_hyps = speech2text(**batch)
-
-                results = speech2text.hypotheses_to_results(final_hyps)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
-                results = [[" ", ["<space>"], [2], hyp]] * nbest
-
-            key = keys[0]
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                item = {'key': key, 'value': text}
-                asr_result_list.append(item)
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    ibest_writer["token"][key] = " ".join(token)
-                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-
-                    if text is not None:
-                        ibest_writer["text"][key] = text
-
-                logging.info("decoding, utt: {}, predictions: {}".format(key, text))
-        return asr_result_list
-    return _forward
-
-
-def inference_sa_asr(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        mc: bool = False,
-        param_dict: dict = None,
-        **kwargs,
-):
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    for handler in logging.root.handlers[:]:
-        logging.root.removeHandler(handler)
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        streaming=streaming,
-    )
-    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
-    speech2text = Speech2TextSAASR(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            mc=mc,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp]] * nbest
-
-            # Only supporting batch_size==1
-            key = keys[0]
-            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    # Write the result to each file
-                    ibest_writer["token"][key] = " ".join(token)
-                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-                    ibest_writer["text_id"][key] = text_id
-
-                if text is not None:
-                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                    item = {'key': key, 'value': text_postprocessed}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    asr_utils.print_progress(finish_count / file_count)
-                    if writer is not None:
-                        ibest_writer["text"][key] = text
-
-                logging.info("uttid: {}".format(key))
-                logging.info("text predictions: {}".format(text))
-                logging.info("text_id predictions: {}\n".format(text_id))
-        return asr_result_list
-
-    return _forward
-
-def inference_whisper(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        mc: bool = False,
-        param_dict: dict = None,
-        **kwargs,
-):
-
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if param_dict:
-        language = param_dict.get("language", None)
-        task = param_dict.get("task", "transcribe")
-    else:
-        language = None
-        task = "transcribe"
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    for handler in logging.root.handlers[:]:
-        logging.root.removeHandler(handler)
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2text
-    speech2text_kwargs = dict(
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        bpemodel=bpemodel,
-        device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        dtype=dtype,
-        beam_size=beam_size,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
-        nbest=nbest,
-        streaming=streaming,
-        language=language,
-        task=task,
-    )
-    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
-    speech2text = Speech2TextWhisper(**speech2text_kwargs)
-
-    def _forward(data_path_and_name_and_type,
-                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-                 output_dir_v2: Optional[str] = None,
-                 fs: dict = None,
-                 param_dict: dict = None,
-                 **kwargs,
-                 ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speech2text.asr_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            mc=mc,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp]] * nbest
-
-            # Only supporting batch_size==1
-            key = keys[0]
-
-            for n, (text, language) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-
-                    # Write the result to each file
-                    ibest_writer["language"][key] = language
-
-                if text is not None:
-                    item = {'key': key, 'value': text}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    if writer is not None:
-                        ibest_writer["text"][key] = text
-
-                logging.info("uttid: {}".format(key))
-                logging.info("text predictions: {}\n".format(text))
-        return asr_result_list
-
-    return _forward
-
-def inference_launch(**kwargs):
-    if 'mode' in kwargs:
-        mode = kwargs['mode']
-    else:
-        logging.info("Unknown decoding mode.")
-        return None
-    if mode == "asr":
-        return inference_asr(**kwargs)
-    elif mode == "uniasr":
-        return inference_uniasr(**kwargs)
-    elif mode == "paraformer":
-        return inference_paraformer(**kwargs)
-    elif mode == "paraformer_fake_streaming":
-        return inference_paraformer(**kwargs)
-    elif mode == "paraformer_streaming":
-        return inference_paraformer_online(**kwargs)
-    elif mode.startswith("paraformer_vad_speaker"):
-        return inference_paraformer_vad_speaker(**kwargs)
-    elif mode.startswith("paraformer_vad"):
-        return inference_paraformer_vad_punc(**kwargs)
-    elif mode == "mfcca":
-        return inference_mfcca(**kwargs)
-    elif mode == "rnnt":
-        return inference_transducer(**kwargs)
-    elif mode == "bat":
-        return inference_transducer(**kwargs)
-    elif mode == "sa_asr":
-        return inference_sa_asr(**kwargs)
-    elif mode == "whisper":
-        return inference_whisper(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    from funasr.bin.argument import get_parser
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="asr",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
deleted file mode 100644
index c03bdf3..0000000
--- a/funasr/bin/build_trainer.py
+++ /dev/null
@@ -1,725 +0,0 @@
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from io import BytesIO
-
-import torch
-import yaml
-
-from funasr.build_utils.build_args import build_args
-from funasr.build_utils.build_dataloader import build_dataloader
-from funasr.build_utils.build_distributed import build_distributed
-from funasr.build_utils.build_model import build_model
-from funasr.build_utils.build_optimizer import build_optimizer
-from funasr.build_utils.build_scheduler import build_scheduler
-from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
-from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.prepare_data import prepare_data
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
-
-
-def update_dct(fin_configs, root):
-    if root == {}:
-        return {}
-    for root_key, root_value in root.items():
-        if not isinstance(root[root_key], dict):
-            fin_configs[root_key] = root[root_key]
-        else:
-            if root_key in fin_configs.keys():
-                result = update_dct(fin_configs[root_key], root[root_key])
-                fin_configs[root_key] = result
-            else:
-                fin_configs[root_key] = root[root_key]
-    return fin_configs
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        description="FunASR Common Training Parser",
-    )
-
-    # common configuration
-    parser.add_argument("--output_dir", help="model save path")
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
-
-    # ddp related
-    parser.add_argument(
-        "--dist_backend",
-        default="nccl",
-        type=str,
-        help="distributed backend",
-    )
-    parser.add_argument(
-        "--dist_init_method",
-        type=str,
-        default="env://",
-        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
-             '"WORLD_SIZE", and "RANK" are referred.',
-    )
-    parser.add_argument(
-        "--dist_world_size",
-        type=int,
-        default=1,
-        help="number of nodes for distributed training",
-    )
-    parser.add_argument(
-        "--dist_rank",
-        type=int,
-        default=None,
-        help="node rank for distributed training",
-    )
-    parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=None,
-        help="local rank for distributed training",
-    )
-    parser.add_argument(
-        "--dist_master_addr",
-        default=None,
-        type=str_or_none,
-        help="The master address for distributed training. "
-             "This value is used when dist_init_method == 'env://'",
-    )
-    parser.add_argument(
-        "--dist_master_port",
-        default=None,
-        type=int_or_none,
-        help="The master port for distributed training"
-             "This value is used when dist_init_method == 'env://'",
-    )
-    parser.add_argument(
-        "--dist_launcher",
-        default=None,
-        type=str_or_none,
-        choices=["slurm", "mpi", None],
-        help="The launcher type for distributed training",
-    )
-    parser.add_argument(
-        "--multiprocessing_distributed",
-        default=True,
-        type=str2bool,
-        help="Use multi-processing distributed training to launch "
-             "N processes per node, which has N GPUs. This is the "
-             "fastest way to use PyTorch for either single node or "
-             "multi node data parallel training",
-    )
-    parser.add_argument(
-        "--unused_parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to use the find_unused_parameters in "
-             "torch.nn.parallel.DistributedDataParallel ",
-    )
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-
-    # cudnn related
-    parser.add_argument(
-        "--cudnn_enabled",
-        type=str2bool,
-        default=torch.backends.cudnn.enabled,
-        help="Enable CUDNN",
-    )
-    parser.add_argument(
-        "--cudnn_benchmark",
-        type=str2bool,
-        default=torch.backends.cudnn.benchmark,
-        help="Enable cudnn-benchmark mode",
-    )
-    parser.add_argument(
-        "--cudnn_deterministic",
-        type=str2bool,
-        default=True,
-        help="Enable cudnn-deterministic mode",
-    )
-
-    # trainer related
-    parser.add_argument(
-        "--max_epoch",
-        type=int,
-        default=40,
-        help="The maximum number epoch to train",
-    )
-    parser.add_argument(
-        "--max_update",
-        type=int,
-        default=sys.maxsize,
-        help="The maximum number update step to train",
-    )
-    parser.add_argument(
-        "--batch_interval",
-        type=int,
-        default=10000,
-        help="The batch interval for saving model.",
-    )
-    parser.add_argument(
-        "--patience",
-        type=int_or_none,
-        default=None,
-        help="Number of epochs to wait without improvement "
-             "before stopping the training",
-    )
-    parser.add_argument(
-        "--val_scheduler_criterion",
-        type=str,
-        nargs=2,
-        default=("valid", "loss"),
-        help="The criterion used for the value given to the lr scheduler. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'and the criterion name. The mode specifying "min" or "max" can '
-             "be changed by --scheduler_conf",
-    )
-    parser.add_argument(
-        "--early_stopping_criterion",
-        type=str,
-        nargs=3,
-        default=("valid", "loss", "min"),
-        help="The criterion used for judging of early stopping. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--best_model_criterion",
-        nargs="+",
-        default=[
-            ("train", "loss", "min"),
-            ("valid", "loss", "min"),
-            ("train", "acc", "max"),
-            ("valid", "acc", "max"),
-        ],
-        help="The criterion used for judging of the best model. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--keep_nbest_models",
-        type=int,
-        nargs="+",
-        default=[10],
-        help="Remove previous snapshots excluding the n-best scored epochs",
-    )
-    parser.add_argument(
-        "--nbest_averaging_interval",
-        type=int,
-        default=0,
-        help="The epoch interval to apply model averaging and save nbest models",
-    )
-    parser.add_argument(
-        "--grad_clip",
-        type=float,
-        default=5.0,
-        help="Gradient norm threshold to clip",
-    )
-    parser.add_argument(
-        "--grad_clip_type",
-        type=float,
-        default=2.0,
-        help="The type of the used p-norm for gradient clip. Can be inf",
-    )
-    parser.add_argument(
-        "--grad_noise",
-        type=str2bool,
-        default=False,
-        help="The flag to switch to use noise injection to "
-             "gradients during training",
-    )
-    parser.add_argument(
-        "--accum_grad",
-        type=int,
-        default=1,
-        help="The number of gradient accumulation",
-    )
-    parser.add_argument(
-        "--resume",
-        type=str2bool,
-        default=False,
-        help="Enable resuming if checkpoint is existing",
-    )
-    parser.add_argument(
-        "--train_dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type for training.",
-    )
-    parser.add_argument(
-        "--use_amp",
-        type=str2bool,
-        default=False,
-        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
-    )
-    parser.add_argument(
-        "--log_interval",
-        default=None,
-        help="Show the logs every the number iterations in each epochs at the "
-             "training phase. If None is given, it is decided according the number "
-             "of training samples automatically .",
-    )
-    parser.add_argument(
-        "--use_tensorboard",
-        type=str2bool,
-        default=True,
-        help="Enable tensorboard logging",
-    )
-
-    # pretrained model related
-    parser.add_argument(
-        "--init_param",
-        type=str,
-        action="append",
-        default=[],
-        help="Specify the file path used for initialization of parameters. "
-             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
-             "where file_path is the model file path, "
-             "src_key specifies the key of model states to be used in the model file, "
-             "dst_key specifies the attribute of the model to be initialized, "
-             "and exclude_keys excludes keys of model states for the initialization."
-             "e.g.\n"
-             "  # Load all parameters"
-             "  --init_param some/where/model.pb\n"
-             "  # Load only decoder parameters"
-             "  --init_param some/where/model.pb:decoder:decoder\n"
-             "  # Load only decoder parameters excluding decoder.embed"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
-    )
-    parser.add_argument(
-        "--ignore_init_mismatch",
-        type=str2bool,
-        default=False,
-        help="Ignore size mismatch when loading pre-trained model",
-    )
-    parser.add_argument(
-        "--freeze_param",
-        type=str,
-        default=[],
-        action="append",
-        help="Freeze parameters",
-    )
-
-    # dataset related
-    parser.add_argument(
-        "--dataset_type",
-        type=str,
-        default="small",
-        help="whether to use dataloader for large dataset",
-    )
-    parser.add_argument(
-        "--dataset_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help=f"The keyword arguments for dataset",
-    )
-    parser.add_argument(
-        "--data_dir",
-        type=str,
-        default=None,
-        help="root path of data",
-    )
-    parser.add_argument(
-        "--train_set",
-        type=str,
-        default="train",
-        help="train dataset",
-    )
-    parser.add_argument(
-        "--valid_set",
-        type=str,
-        default="validation",
-        help="dev dataset",
-    )
-    parser.add_argument(
-        "--data_file_names",
-        type=str,
-        default="wav.scp,text",
-        help="input data files",
-    )
-    parser.add_argument(
-        "--speed_perturb",
-        type=float,
-        nargs="+",
-        default=None,
-        help="speed perturb",
-    )
-    parser.add_argument(
-        "--use_preprocessor",
-        type=str2bool,
-        default=True,
-        help="Apply preprocessing to data or not",
-    )
-
-    # optimization related
-    parser.add_argument(
-        "--optim",
-        type=lambda x: x.lower(),
-        default="adam",
-        help="The optimizer type",
-    )
-    parser.add_argument(
-        "--optim_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help="The keyword arguments for optimizer",
-    )
-    parser.add_argument(
-        "--scheduler",
-        type=lambda x: str_or_none(x.lower()),
-        default=None,
-        help="The lr scheduler type",
-    )
-    parser.add_argument(
-        "--scheduler_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help="The keyword arguments for lr scheduler",
-    )
-
-    # most task related
-    parser.add_argument(
-        "--init",
-        type=lambda x: str_or_none(x.lower()),
-        default=None,
-        help="The initialization method",
-        choices=[
-            "chainer",
-            "xavier_uniform",
-            "xavier_normal",
-            "kaiming_uniform",
-            "kaiming_normal",
-            None,
-        ],
-    )
-    parser.add_argument(
-        "--token_list",
-        type=str_or_none,
-        default=None,
-        help="A text mapping int-id to token",
-    )
-    parser.add_argument(
-        "--token_type",
-        type=str,
-        default="bpe",
-        choices=["bpe", "char", "word"],
-        help="",
-    )
-    parser.add_argument(
-        "--bpemodel",
-        type=str_or_none,
-        default=None,
-        help="The model file fo sentencepiece",
-    )
-    parser.add_argument(
-        "--cleaner",
-        type=str_or_none,
-        choices=[None, "tacotron", "jaconv", "vietnamese"],
-        default=None,
-        help="Apply text cleaning",
-    )
-    parser.add_argument(
-        "--g2p",
-        type=str_or_none,
-        choices=g2p_choices,
-        default=None,
-        help="Specify g2p method if --token_type=phn",
-    )
-
-    # pai related
-    parser.add_argument(
-        "--use_pai",
-        type=str2bool,
-        default=False,
-        help="flag to indicate whether training on PAI",
-    )
-    parser.add_argument(
-        "--simple_ddp",
-        type=str2bool,
-        default=False,
-    )
-    parser.add_argument(
-        "--num_worker_count",
-        type=int,
-        default=1,
-        help="The number of machines on PAI.",
-    )
-    parser.add_argument(
-        "--access_key_id",
-        type=str,
-        default=None,
-        help="The username for oss.",
-    )
-    parser.add_argument(
-        "--access_key_secret",
-        type=str,
-        default=None,
-        help="The password for oss.",
-    )
-    parser.add_argument(
-        "--endpoint",
-        type=str,
-        default=None,
-        help="The endpoint for oss.",
-    )
-    parser.add_argument(
-        "--bucket_name",
-        type=str,
-        default=None,
-        help="The bucket name for oss.",
-    )
-    parser.add_argument(
-        "--oss_bucket",
-        default=None,
-        help="oss bucket.",
-    )
-    parser.add_argument(
-        "--enable_lora",
-        type=str2bool,
-        default=False,
-        help="Apply lora for finetuning.",
-    )
-    parser.add_argument(
-        "--lora_bias",
-        type=str,
-        default="none",
-        help="lora bias.",
-    )
-
-    return parser
-
-
-def build_trainer(modelscope_dict,
-                  data_dir,
-                  output_dir,
-                  train_set="train",
-                  dev_set="validation",
-                  distributed=False,
-                  dataset_type="small",
-                  batch_bins=None,
-                  max_epoch=None,
-                  optim=None,
-                  lr=None,
-                  scheduler=None,
-                  scheduler_conf=None,
-                  specaug=None,
-                  specaug_conf=None,
-                  mate_params=None,
-                  **kwargs):
-    parser = get_parser()
-    args, extra_task_params = parser.parse_known_args()
-    args = build_args(args, parser, extra_task_params)
-
-    if args.local_rank is not None:
-        distributed = True
-    else:
-        distributed = False
-    args.local_rank = args.local_rank if args.local_rank is not None else 0
-    local_rank = args.local_rank
-    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
-        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
-    else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
-    config = modelscope_dict['am_model_config']
-    finetune_config = modelscope_dict['finetune_config']
-    init_param = modelscope_dict['init_model']
-    cmvn_file = modelscope_dict['cmvn_file']
-    seg_dict_file = modelscope_dict['seg_dict']
-    if 'bpemodel' in modelscope_dict:
-        bpemodel = modelscope_dict['bpemodel']
-    else:
-        bpemodel = None
-
-    # overwrite parameters
-    with open(config) as f:
-        configs = yaml.safe_load(f)
-    with open(finetune_config) as f:
-        finetune_configs = yaml.safe_load(f)
-        # set data_types
-        if dataset_type == "large":
-            # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-            if 'data_types' not in finetune_configs['dataset_conf']:
-                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-    finetune_configs = update_dct(configs, finetune_configs)
-    for key, value in finetune_configs.items():
-        if hasattr(args, key):
-            setattr(args, key, value)
-    if mate_params is not None:
-        for key, value in mate_params.items():
-            if hasattr(args, key):
-                setattr(args, key, value)
-    if mate_params is not None and "lora_params" in mate_params:
-        lora_params = mate_params['lora_params']
-        configs['encoder_conf'].update(lora_params)
-        configs['decoder_conf'].update(lora_params)
-    args.dataset_type = dataset_type
-    args.init_param = [init_param]
-    if mate_params is not None and "init_param" in mate_params:
-        if len(mate_params["init_param"]) != 0:
-            args.init_param = mate_params["init_param"]
-    args.cmvn_file = cmvn_file
-    if os.path.exists(seg_dict_file):
-        args.seg_dict_file = seg_dict_file
-    else:
-        args.seg_dict_file = None
-    if bpemodel is not None and os.path.exists(bpemodel):
-        args.bpemodel = bpemodel
-    else:
-        args.bpemodel = None
-    args.data_dir = data_dir
-    args.train_set = train_set
-    args.dev_set = dev_set
-    args.output_dir = output_dir
-    args.gpu_id = args.local_rank
-    args.config = finetune_config
-    args.use_pai = False
-    args.batch_type = "length"
-    args.oss_bucket = None
-    args.input_size = None
-    if distributed:
-        args.distributed = True
-        args.simple_ddp = True
-    else:
-        args.distributed = False
-        args.ngpu = 1
-    if optim is not None:
-        args.optim = optim
-    if lr is not None:
-        args.optim_conf["lr"] = lr
-    if scheduler is not None:
-        args.scheduler = scheduler
-    if scheduler_conf is not None:
-        args.scheduler_conf = scheduler_conf
-    if specaug is not None:
-        args.specaug = specaug
-    if specaug_conf is not None:
-        args.specaug_conf = specaug_conf
-    if max_epoch is not None:
-        args.max_epoch = max_epoch
-    if batch_bins is not None:
-        if args.dataset_type == "small":
-            args.batch_bins = batch_bins
-            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
-        elif args.dataset_type == "large":
-            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
-        else:
-            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    if args.normalize in ["null", "none", "None"]:
-        args.normalize = None
-    if args.patience in ["null", "none", "None"]:
-        args.patience = None
-    args.local_rank = local_rank
-
-    # set random seed
-    set_all_random_seed(args.seed)
-    torch.backends.cudnn.enabled = args.cudnn_enabled
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-    torch.backends.cudnn.deterministic = args.cudnn_deterministic
-
-    # ddp init
-    distributed_option = build_distributed(args)
-
-    # for logging
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        logging.basicConfig(
-            level="INFO",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-    else:
-        logging.basicConfig(
-            level="ERROR",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-
-    # prepare files for dataloader
-    prepare_data(args, distributed_option)
-
-    model = build_model(args)
-    model = model.to(
-        dtype=getattr(torch, args.train_dtype),
-        device="cuda" if args.ngpu > 0 else "cpu",
-    )
-    if args.enable_lora:
-        mark_only_lora_as_trainable(model, args.lora_bias)
-    for t in args.freeze_param:
-        for k, p in model.named_parameters():
-            if k.startswith(t + ".") or k == t:
-                logging.info(f"Setting {k}.requires_grad = False")
-                p.requires_grad = False
-
-    optimizers = build_optimizer(args, model=model)
-    schedulers = build_scheduler(args, optimizers)
-
-    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                   distributed_option.dist_rank,
-                                                                   distributed_option.local_rank))
-    logging.info(pytorch_cudnn_version())
-    logging.info("Args: {}".format(args))
-    logging.info(model_summary(model))
-    logging.info("Optimizer: {}".format(optimizers))
-    logging.info("Scheduler: {}".format(schedulers))
-
-    # dump args to config.yaml
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
-            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
-            if args.use_pai:
-                buffer = BytesIO()
-                torch.save({"config": vars(args)}, buffer)
-                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
-            else:
-                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
-
-    for p in args.init_param:
-        logging.info(f"Loading pretrained params from {p}")
-        load_pretrained_model(
-            model=model,
-            init_param=p,
-            ignore_init_mismatch=args.ignore_init_mismatch,
-            map_location=f"cuda:{torch.cuda.current_device()}"
-            if args.ngpu > 0
-            else "cpu",
-            oss_bucket=args.oss_bucket,
-        )
-
-    # dataloader for training/validation
-    train_dataloader, valid_dataloader = build_dataloader(args)
-
-    # Trainer, including model, optimizers, etc.
-    trainer = build_trainer_modelscope(
-        args=args,
-        model=model,
-        optimizers=optimizers,
-        schedulers=schedulers,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        distributed_option=distributed_option
-    )
-
-    return trainer
diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
deleted file mode 100755
index b9dbdff..0000000
--- a/funasr/bin/data2vec_train.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.data2vec import Data2VecTask
-
-
-def parse_args():
-    parser = Data2VecTask.get_parser()
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    # for data2vec Training
-    Data2VecTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
-    args = parse_args()
-
-    # setup local gpu_id
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-    assert args.num_worker_count == 1
-
-    # re-compute batch size: when dataset type is small
-    if args.dataset_type == "small":
-        if args.batch_size is not None:
-            args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None:
-            args.batch_bins = args.batch_bins * args.ngpu
-
-    main(args=args)
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
deleted file mode 100755
index bb40f5e..0000000
--- a/funasr/bin/diar_infer.py
+++ /dev/null
@@ -1,272 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import logging
-import os
-from collections import OrderedDict
-from pathlib import Path
-from typing import Any
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from scipy.ndimage import median_filter
-from torch.nn import functional as F
-
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
-from funasr.tasks.diar import DiarTask
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.misc import statistic_model_parameters
-
-
-class Speech2DiarizationEEND:
-    """Speech2Diarlization class
-
-    Examples:
-        >>> import librosa
-        >>> import numpy as np
-        >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
-        >>> profile = np.load("profiles.npy")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2diar(audio, profile)
-        {"spk1": [(int, int), ...], ...}
-
-    """
-
-    def __init__(
-            self,
-            diar_train_config: Union[Path, str] = None,
-            diar_model_file: Union[Path, str] = None,
-            device: str = "cpu",
-            dtype: str = "float32",
-    ):
-
-        # 1. Build Diarization model
-        diar_model, diar_train_args = build_model_from_file(
-            config_file=diar_train_config,
-            model_file=diar_model_file,
-            device=device,
-            task_name="diar",
-            mode="eend-ola",
-        )
-        frontend = None
-        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
-            frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
-
-        # set up seed for eda
-        np.random.seed(diar_train_args.seed)
-        torch.manual_seed(diar_train_args.seed)
-        torch.cuda.manual_seed(diar_train_args.seed)
-        os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
-        logging.info("diar_model: {}".format(diar_model))
-        logging.info("diar_train_args: {}".format(diar_train_args))
-        diar_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.diar_model = diar_model
-        self.diar_train_args = diar_train_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-
-    @torch.no_grad()
-    def __call__(
-            self,
-            speech: Union[torch.Tensor, np.ndarray],
-            speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ):
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            diarization results
-
-        """
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.diar_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        batch = {"speech": feats, "speech_lengths": feats_len}
-        batch = to_device(batch, device=self.device)
-        results = self.diar_model.estimate_sequential(**batch)
-
-        return results
-
-
-class Speech2DiarizationSOND:
-    """Speech2Xvector class
-
-    Examples:
-        >>> import librosa
-        >>> import numpy as np
-        >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
-        >>> profile = np.load("profiles.npy")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2diar(audio, profile)
-        {"spk1": [(int, int), ...], ...}
-
-    """
-
-    def __init__(
-            self,
-            diar_train_config: Union[Path, str] = None,
-            diar_model_file: Union[Path, str] = None,
-            device: Union[str, torch.device] = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            streaming: bool = False,
-            smooth_size: int = 83,
-            dur_threshold: float = 10,
-    ):
-
-        # TODO: 1. Build Diarization model
-        diar_model, diar_train_args = build_model_from_file(
-            config_file=diar_train_config,
-            model_file=diar_model_file,
-            device=device,
-            task_name="diar",
-            mode="sond",
-        )
-        logging.info("diar_model: {}".format(diar_model))
-        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
-        logging.info("diar_train_args: {}".format(diar_train_args))
-        diar_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.diar_model = diar_model
-        self.diar_train_args = diar_train_args
-        self.token_list = diar_train_args.token_list
-        self.smooth_size = smooth_size
-        self.dur_threshold = dur_threshold
-        self.device = device
-        self.dtype = dtype
-
-    def smooth_multi_labels(self, multi_label):
-        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
-        return multi_label
-
-    @staticmethod
-    def calc_spk_turns(label_arr, spk_list):
-        turn_list = []
-        length = label_arr.shape[0]
-        n_spk = label_arr.shape[1]
-        for k in range(n_spk):
-            if spk_list[k] == "None":
-                continue
-            in_utt = False
-            start = 0
-            for i in range(length):
-                if label_arr[i, k] == 1 and in_utt is False:
-                    start = i
-                    in_utt = True
-                if label_arr[i, k] == 0 and in_utt is True:
-                    turn_list.append([spk_list[k], start, i - start])
-                    in_utt = False
-            if in_utt:
-                turn_list.append([spk_list[k], start, length - start])
-        return turn_list
-
-    @staticmethod
-    def seq2arr(seq, vec_dim=8):
-        def int2vec(x, vec_dim=8, dtype=np.int32):
-            b = ('{:0' + str(vec_dim) + 'b}').format(x)
-            # little-endian order: lower bit first
-            return (np.array(list(b)[::-1]) == '1').astype(dtype)
-
-        # process oov
-        seq = np.array([int(x) for x in seq])
-        new_seq = []
-        for i, x in enumerate(seq):
-            if x < 2 ** vec_dim:
-                new_seq.append(x)
-            else:
-                idx_list = np.where(seq < 2 ** vec_dim)[0]
-                if len(idx_list) > 0:
-                    idx = np.abs(idx_list - i).argmin()
-                    new_seq.append(seq[idx_list[idx]])
-                else:
-                    new_seq.append(0)
-        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
-
-    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
-        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
-        # upsampling outputs to match inputs
-        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
-        logits_idx = F.upsample(
-            logits_idx.unsqueeze(1).float(),
-            size=(ut,),
-            mode="nearest",
-        ).squeeze(1).long()
-        logits_idx = logits_idx[0].tolist()
-        pse_labels = [self.token_list[x] for x in logits_idx]
-        if output_format == "pse_labels":
-            return pse_labels, None
-
-        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
-        multi_labels = self.smooth_multi_labels(multi_labels)
-        if output_format == "binary_labels":
-            return multi_labels, None
-
-        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
-        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
-        results = OrderedDict()
-        for spk, st, dur in spk_turns:
-            if spk not in results:
-                results[spk] = []
-            if dur > self.dur_threshold:
-                results[spk].append((st, st + dur))
-
-        # sort segments in start time ascending
-        for spk in results:
-            results[spk] = sorted(results[spk], key=lambda x: x[0])
-
-        return results, pse_labels
-
-    @torch.no_grad()
-    def __call__(
-            self,
-            speech: Union[torch.Tensor, np.ndarray],
-            profile: Union[torch.Tensor, np.ndarray],
-            output_format: str = "speaker_turn"
-    ):
-        """Inference
-
-        Args:
-            speech: Input speech data
-            profile: Speaker profiles
-        Returns:
-            diarization results for each speaker
-
-        """
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        if isinstance(profile, np.ndarray):
-            profile = torch.tensor(profile)
-
-        # data: (Nsamples,) -> (1, Nsamples)
-        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
-        # lengths: (1,)
-        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
-        batch = {"speech": speech, "speech_lengths": speech_lengths,
-                 "profile": profile, "profile_lengths": profile_lengths}
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        logits = self.diar_model.prediction_forward(**batch)
-        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
-
-        return results, pse_labels
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
deleted file mode 100755
index f5a11b1..0000000
--- a/funasr/bin/diar_inference_launch.py
+++ /dev/null
@@ -1,506 +0,0 @@
-# !/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-# import librosa
-import librosa
-import torch
-from scipy.signal import medfilt
-
-from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_sond(
-        diar_train_config: str,
-        diar_model_file: str,
-        output_dir: Optional[str] = None,
-        batch_size: int = 1,
-        dtype: str = "float32",
-        ngpu: int = 0,
-        seed: int = 0,
-        num_workers: int = 0,
-        log_level: Union[int, str] = "INFO",
-        key_file: Optional[str] = None,
-        model_tag: Optional[str] = None,
-        allow_variable_data_keys: bool = True,
-        streaming: bool = False,
-        smooth_size: int = 83,
-        dur_threshold: int = 10,
-        out_format: str = "vad",
-        param_dict: Optional[dict] = None,
-        mode: str = "sond",
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("param_dict: {}".format(param_dict))
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2a. Build speech2xvec [Optional]
-    if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
-        "extract_profile"]:
-        assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
-        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
-        sv_train_config = param_dict["sv_train_config"]
-        sv_model_file = param_dict["sv_model_file"]
-        if "model_dir" in param_dict:
-            sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
-            sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
-        from funasr.bin.sv_infer import Speech2Xvector
-        speech2xvector_kwargs = dict(
-            sv_train_config=sv_train_config,
-            sv_model_file=sv_model_file,
-            device=device,
-            dtype=dtype,
-            streaming=streaming,
-            embedding_node="resnet1_dense"
-        )
-        logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
-        speech2xvector.sv_model.eval()
-
-    # 2b. Build speech2diar
-    speech2diar_kwargs = dict(
-        diar_train_config=diar_train_config,
-        diar_model_file=diar_model_file,
-        device=device,
-        dtype=dtype,
-        streaming=streaming,
-        smooth_size=smooth_size,
-        dur_threshold=dur_threshold,
-    )
-    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
-    speech2diar.diar_model.eval()
-
-    def output_results_str(results: dict, uttid: str):
-        rst = []
-        mid = uttid.rsplit("-", 1)[0]
-        for key in results:
-            results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
-        if out_format == "vad":
-            for spk, segs in results.items():
-                rst.append("{} {}".format(spk, segs))
-        else:
-            template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
-            for spk, segs in results.items():
-                rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
-        return "\n".join(rst)
-
-    def _forward(
-            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
-            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
-            output_dir_v2: Optional[str] = None,
-            param_dict: Optional[dict] = None,
-    ):
-        logging.info("param_dict: {}".format(param_dict))
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, (list, tuple)):
-                if not isinstance(raw_inputs[0], List):
-                    raw_inputs = [raw_inputs]
-
-                assert all([len(example) >= 2 for example in raw_inputs]), \
-                    "The length of test case in raw_inputs must larger than 1 (>=2)."
-
-                def prepare_dataset():
-                    for idx, example in enumerate(raw_inputs):
-                        # read waveform file
-                        example = [load_bytes(x) if isinstance(x, bytes) else x
-                                   for x in example]
-                        # example = [librosa.load(x)[0] if isinstance(x, str) else x
-                        #            for x in example]
-                        example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
-                                   for x in example]
-                        # convert torch tensor to numpy array
-                        example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
-                                   for x in example]
-                        speech = example[0]
-                        logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
-                        profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
-                        profile = torch.cat(profile, dim=0)
-                        yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
-
-                loader = prepare_dataset()
-            else:
-                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
-        else:
-            # 3. Build data-iterator
-            loader = build_streaming_iterator(
-                task_name="diar",
-                preprocess_args=None,
-                data_path_and_name_and_type=data_path_and_name_and_type,
-                dtype=dtype,
-                batch_size=batch_size,
-                key_file=key_file,
-                num_workers=num_workers,
-                use_collate_fn=False,
-            )
-
-        # 7. Start for-loop
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            os.makedirs(output_path, exist_ok=True)
-            output_writer = open("{}/result.txt".format(output_path), "w")
-            pse_label_writer = open("{}/labels.txt".format(output_path), "w")
-        logging.info("Start to diarize...")
-        result_list = []
-        for idx, (keys, batch) in enumerate(loader):
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            results, pse_labels = speech2diar(**batch)
-            # Only supporting batch_size==1
-            key, value = keys[0], output_results_str(results, keys[0])
-            item = {"key": key, "value": value}
-            result_list.append(item)
-            if output_path is not None:
-                output_writer.write(value)
-                output_writer.flush()
-                pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
-                pse_label_writer.flush()
-
-            if idx % 100 == 0:
-                logging.info("Processing {:5d}: {}".format(idx, key))
-
-        if output_path is not None:
-            output_writer.close()
-            pse_label_writer.close()
-
-        return result_list
-
-    return _forward
-
-
-def inference_eend(
-        diar_train_config: str,
-        diar_model_file: str,
-        output_dir: Optional[str] = None,
-        batch_size: int = 1,
-        dtype: str = "float32",
-        ngpu: int = 1,
-        num_workers: int = 0,
-        log_level: Union[int, str] = "INFO",
-        key_file: Optional[str] = None,
-        model_tag: Optional[str] = None,
-        allow_variable_data_keys: bool = True,
-        streaming: bool = False,
-        param_dict: Optional[dict] = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("param_dict: {}".format(param_dict))
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Build speech2diar
-    speech2diar_kwargs = dict(
-        diar_train_config=diar_train_config,
-        diar_model_file=diar_model_file,
-        device=device,
-        dtype=dtype,
-    )
-    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
-    speech2diar.diar_model.eval()
-
-    def output_results_str(results: dict, uttid: str):
-        rst = []
-        mid = uttid.rsplit("-", 1)[0]
-        for key in results:
-            results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
-        template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
-        for spk, segs in results.items():
-            rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
-        return "\n".join(rst)
-
-    def _forward(
-            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
-            raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
-            output_dir_v2: Optional[str] = None,
-            param_dict: Optional[dict] = None,
-    ):
-        # 2. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
-        loader = build_streaming_iterator(
-            task_name="diar",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        # 3. Start for-loop
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            os.makedirs(output_path, exist_ok=True)
-            output_writer = open("{}/result.txt".format(output_path), "w")
-        result_list = []
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            results = speech2diar(**batch)
-
-            # post process
-            a = results[0][0].cpu().numpy()
-            a = medfilt(a, (11, 1))
-            rst = []
-            for spkid, frames in enumerate(a.T):
-                frames = np.pad(frames, (1, 1), 'constant')
-                changes, = np.where(np.diff(frames, axis=0) != 0)
-                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
-                for s, e in zip(changes[::2], changes[1::2]):
-                    st = s / 10.
-                    dur = (e - s) / 10.
-                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
-
-            # Only supporting batch_size==1
-            value = "\n".join(rst)
-            item = {"key": keys[0], "value": value}
-            result_list.append(item)
-            if output_path is not None:
-                output_writer.write(value)
-                output_writer.flush()
-
-        if output_path is not None:
-            output_writer.close()
-
-        return result_list
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "sond":
-        return inference_sond(mode=mode, **kwargs)
-    elif mode == "sond_demo":
-        param_dict = {
-            "extract_profile": True,
-            "sv_train_config": "sv.yaml",
-            "sv_model_file": "sv.pb",
-        }
-        if "param_dict" in kwargs and kwargs["param_dict"] is not None:
-            for key in param_dict:
-                if key not in kwargs["param_dict"]:
-                    kwargs["param_dict"][key] = param_dict[key]
-        else:
-            kwargs["param_dict"] = param_dict
-        return inference_sond(mode=mode, **kwargs)
-    elif mode == "eend-ola":
-        return inference_eend(mode=mode, **kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Speaker Verification",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=False)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=False,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--diar_train_config",
-        type=str,
-        help="ASR training configuration",
-    )
-    group.add_argument(
-        "--diar_model_file",
-        type=str,
-        help="ASR model parameter file",
-    )
-    group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--model_tag",
-        type=str,
-        help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
-    )
-
-    group = parser.add_argument_group("The inference configuration related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    group.add_argument(
-        "--smooth_size",
-        type=int,
-        default=121,
-        help="The smoothing size for post-processing"
-    )
-    group.add_argument(
-        "--dur_threshold",
-        type=int,
-        default=10,
-        help="The threshold of minimum duration"
-    )
-
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="sond",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
deleted file mode 100755
index 16a4bd0..0000000
--- a/funasr/bin/diar_train.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.diar import DiarTask
-
-
-# for ASR Training
-def parse_args():
-    parser = DiarTask.get_parser()
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    # for ASR Training
-    DiarTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
-    args = parse_args()
-
-    # setup local gpu_id
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-    assert args.num_worker_count == 1
-
-    # re-compute batch size: when dataset type is small
-    if args.dataset_type == "small":
-        if args.batch_size is not None:
-            args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None:
-            args.batch_bins = args.batch_bins * args.ngpu
-
-    main(args=args)
diff --git a/funasr/export/export_model.py b/funasr/bin/export_model.py
similarity index 100%
rename from funasr/export/export_model.py
rename to funasr/bin/export_model.py
diff --git a/funasr/bin/inference_cli.py b/funasr/bin/inference_cli.py
deleted file mode 100644
index f4c66f1..0000000
--- a/funasr/bin/inference_cli.py
+++ /dev/null
@@ -1,139 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-import logging
-import torch
-import numpy as np
-from funasr.utils.download_and_prepare_model import prepare_model
-
-from funasr.utils.types import str2bool
-
-def infer(task_name: str = "asr",
-          model: str = None,
-          # mode: str = None,
-          vad_model: str = None,
-          disable_vad: bool = False,
-          punc_model: str = None,
-          disable_punc: bool = False,
-          model_hub: str = "ms",
-          cache_dir: str = None,
-          **kwargs,
-          ):
-
-	# set logging messages
-	logging.basicConfig(
-		level=logging.ERROR,
-	)
-
-	model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
-	if task_name == "asr":
-		from funasr.bin.asr_inference_launch import inference_launch
-		
-		inference_pipeline = inference_launch(**kwargs)
-	elif task_name == "":
-		pipeline = 1
-	elif task_name == "":
-		pipeline = 2
-	elif task_name == "":
-		pipeline = 2
-	
-	def _infer_fn(input, **kwargs):
-		data_type = kwargs.get('data_type', 'sound')
-		data_path_and_name_and_type = [input, 'speech', data_type]
-		raw_inputs = None
-		if isinstance(input, torch.Tensor):
-			input = input.numpy()
-		if isinstance(input, np.ndarray):
-			data_path_and_name_and_type = None
-			raw_inputs = input
-		
-		return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
-	
-	return _infer_fn
-
-
-def main(cmd=None):
-	# print(get_commandline_args(), file=sys.stderr)
-	from funasr.bin.argument import get_parser
-	
-	parser = get_parser()
-	parser.add_argument('input', help='input file to transcribe')
-	parser.add_argument(
-	    "--task_name",
-	    type=str,
-	    default="asr",
-	    help="The decoding mode",
-	)
-	parser.add_argument(
-		"-m",
-	    "--model",
-	    type=str,
-	    default="paraformer-zh",
-	    help="The asr mode name",
-	)
-	parser.add_argument(
-		"-v",
-	    "--vad_model",
-	    type=str,
-	    default="fsmn-vad",
-	    help="vad model name",
-	)
-	parser.add_argument(
-		"-dv",
-	    "--disable_vad",
-	    type=str2bool,
-	    default=False,
-	    help="",
-	)
-	parser.add_argument(
-		"-p",
-	    "--punc_model",
-	    type=str,
-	    default="ct-punc",
-	    help="",
-	)
-	parser.add_argument(
-		"-dp",
-	    "--disable_punc",
-	    type=str2bool,
-	    default=False,
-	    help="",
-	)
-	parser.add_argument(
-	    "--batch_size_token",
-	    type=int,
-	    default=5000,
-	    help="",
-	)
-	parser.add_argument(
-	    "--batch_size_token_threshold_s",
-	    type=int,
-	    default=35,
-	    help="",
-	)
-	parser.add_argument(
-	    "--max_single_segment_time",
-	    type=int,
-	    default=5000,
-	    help="",
-	)
-	args = parser.parse_args(cmd)
-	kwargs = vars(args)
-	
-	# set logging messages
-	logging.basicConfig(
-		level=logging.ERROR,
-	)
-	logging.info("Decoding args: {}".format(kwargs))
-	
-	# kwargs["ncpu"] = 2 #os.cpu_count()
-	kwargs.pop("data_path_and_name_and_type")
-	print("args: {}".format(kwargs))
-	p = infer(**kwargs)
-	
-	res = p(**kwargs)
-	print(res)
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
deleted file mode 100644
index f12f50a..0000000
--- a/funasr/bin/lm_inference_launch.py
+++ /dev/null
@@ -1,392 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_lm(
-        batch_size: int,
-        dtype: str,
-        ngpu: int,
-        seed: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        key_file: Optional[str],
-        train_config: Optional[str],
-        model_file: Optional[str],
-        log_base: Optional[float] = 10,
-        allow_variable_data_keys: bool = False,
-        split_with_space: Optional[bool] = False,
-        seg_dict_file: Optional[str] = None,
-        output_dir: Optional[str] = None,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build Model
-    model, train_args = build_model_from_file(
-        train_config, model_file, None, device, "lm")
-    wrapped_model = ForwardAdaptor(model, "nll")
-    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-    logging.info(f"Model:\n{model}")
-
-    preprocessor = LMPreprocessor(
-        train=False,
-        token_type=train_args.token_type,
-        token_list=train_args.token_list,
-        bpemodel=train_args.bpemodel,
-        text_cleaner=train_args.cleaner,
-        g2p_type=train_args.g2p,
-        text_name="text",
-        non_linguistic_symbols=train_args.non_linguistic_symbols,
-        split_with_space=split_with_space,
-        seg_dict_file=seg_dict_file
-    )
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[List[Any], bytes, str] = None,
-            output_dir_v2: Optional[str] = None,
-            param_dict: dict = None,
-    ):
-        results = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
-        if raw_inputs != None:
-            line = raw_inputs.strip()
-            key = "lm demo"
-            if line == "":
-                item = {'key': key, 'value': ""}
-                results.append(item)
-                return results
-            batch = {}
-            batch['text'] = line
-            if preprocessor != None:
-                batch = preprocessor(key, batch)
-
-            #  Force data-precision
-            for name in batch:
-                value = batch[name]
-                if not isinstance(value, np.ndarray):
-                    raise RuntimeError(
-                        f"All values must be converted to np.ndarray object "
-                        f'by preprocessing, but "{name}" is still {type(value)}.'
-                    )
-                # Cast to desired type
-                if value.dtype.kind == "f":
-                    value = value.astype("float32")
-                elif value.dtype.kind == "i":
-                    value = value.astype("long")
-                else:
-                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
-                batch[name] = value
-
-            batch["text_lengths"] = torch.from_numpy(
-                np.array([len(batch["text"])], dtype='int32'))
-            batch["text"] = np.expand_dims(batch["text"], axis=0)
-
-            with torch.no_grad():
-                batch = to_device(batch, device)
-                if ngpu <= 1:
-                    nll, lengths = wrapped_model(**batch)
-                else:
-                    nll, lengths = data_parallel(
-                        wrapped_model, (), range(ngpu), module_kwargs=batch
-                    )
-                ## compute ppl
-                ppl_out_batch = ""
-                ids2tokens = preprocessor.token_id_converter.ids2tokens
-                for sent_ids, sent_nll in zip(batch['text'], nll):
-                    pre_word = "<s>"
-                    cur_word = None
-                    sent_lst = ids2tokens(sent_ids) + ['</s>']
-                    ppl_out = " ".join(sent_lst) + "\n"
-                    for word, word_nll in zip(sent_lst, sent_nll):
-                        cur_word = word
-                        word_nll = -word_nll.cpu()
-                        if log_base is None:
-                            word_prob = np.exp(word_nll)
-                        else:
-                            word_prob = log_base ** (word_nll / np.log(log_base))
-                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
-                            cur=cur_word,
-                            pre=pre_word,
-                            prob=round(word_prob.item(), 8),
-                            word_nll=round(word_nll.item(), 8)
-                        )
-                        pre_word = cur_word
-
-                    sent_nll_mean = sent_nll.mean().cpu().numpy()
-                    sent_nll_sum = sent_nll.sum().cpu().numpy()
-                    if log_base is None:
-                        sent_ppl = np.exp(sent_nll_mean)
-                    else:
-                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
-                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
-                        sent_nll=round(-sent_nll_sum.item(), 4),
-                        sent_ppl=round(sent_ppl.item(), 4)
-                    )
-                    ppl_out_batch += ppl_out
-                    item = {'key': key, 'value': ppl_out}
-                    if writer is not None:
-                        writer["ppl"][key + ":\n"] = ppl_out
-                    results.append(item)
-
-            return results
-
-        # 3. Build data-iterator
-        loader = build_streaming_iterator(
-            task_name="lm",
-            preprocess_args=train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            preprocess_fn=preprocessor,
-            num_workers=num_workers,
-        )
-
-        # 4. Start for-loop
-        total_nll = 0.0
-        total_ntokens = 0
-        ppl_out_all = ""
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            ppl_out_batch = ""
-            with torch.no_grad():
-                batch = to_device(batch, device)
-                if ngpu <= 1:
-                    # NOTE(kamo): data_parallel also should work with ngpu=1,
-                    # but for debuggability it's better to keep this block.
-                    nll, lengths = wrapped_model(**batch)
-                else:
-                    nll, lengths = data_parallel(
-                        wrapped_model, (), range(ngpu), module_kwargs=batch
-                    )
-                ## print ppl
-                ids2tokens = preprocessor.token_id_converter.ids2tokens
-                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
-                    pre_word = "<s>"
-                    cur_word = None
-                    sent_lst = ids2tokens(sent_ids) + ['</s>']
-                    ppl_out = " ".join(sent_lst) + "\n"
-                    for word, word_nll in zip(sent_lst, sent_nll):
-                        cur_word = word
-                        word_nll = -word_nll.cpu()
-                        if log_base is None:
-                            word_prob = np.exp(word_nll)
-                        else:
-                            word_prob = log_base ** (word_nll / np.log(log_base))
-                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
-                            cur=cur_word,
-                            pre=pre_word,
-                            prob=round(word_prob.item(), 8),
-                            word_nll=round(word_nll.item(), 8)
-                        )
-                        pre_word = cur_word
-
-                    sent_nll_mean = sent_nll.mean().cpu().numpy()
-                    sent_nll_sum = sent_nll.sum().cpu().numpy()
-                    if log_base is None:
-                        sent_ppl = np.exp(sent_nll_mean)
-                    else:
-                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
-                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
-                        sent_nll=round(-sent_nll_sum.item(), 4),
-                        sent_ppl=round(sent_ppl.item(), 4)
-                    )
-                    ppl_out_batch += ppl_out
-                    utt2nll = round(-sent_nll_sum.item(), 5)
-                    item = {'key': key, 'value': ppl_out}
-                    if writer is not None:
-                        writer["ppl"][key + ":\n"] = ppl_out
-                        writer["utt2nll"][key] = str(utt2nll)
-                    results.append(item)
-
-            ppl_out_all += ppl_out_batch
-
-            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
-            # nll: (B, L) -> (B,)
-            nll = nll.detach().cpu().numpy().sum(1)
-            # lengths: (B,)
-            lengths = lengths.detach().cpu().numpy()
-            total_nll += nll.sum()
-            total_ntokens += lengths.sum()
-
-        if log_base is None:
-            ppl = np.exp(total_nll / total_ntokens)
-        else:
-            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
-        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
-            total_nll=round(-total_nll.item(), 4),
-            total_ppl=round(ppl.item(), 4)
-        )
-        item = {'key': 'AVG PPL', 'value': avg_ppl}
-        ppl_out_all += avg_ppl
-        if writer is not None:
-            writer["ppl"]["AVG PPL : "] = avg_ppl
-        results.append(item)
-
-        return results
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "transformer":
-        return inference_lm(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Calc perplexity",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument("--gpuid_list", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument("--njob", type=int, default=1, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    parser.add_argument(
-        "--log_base",
-        type=float_or_none,
-        default=10,
-        help="The base of logarithm for Perplexity. "
-             "If None, napier's constant is used.",
-        required=False
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        required=False
-    )
-    group.add_argument(
-        "--raw_inputs",
-        type=str,
-        required=False
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
-    group.add_argument("--split_with_space", type=str2bool, default=False)
-    group.add_argument("--seg_dict_file", type=str_or_none)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument("--train_config", type=str)
-    group.add_argument("--model_file", type=str)
-    group.add_argument("--mode", type=str, default="lm")
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    kwargs.pop("gpuid_list", None)
-    kwargs.pop("njob", None)
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
deleted file mode 100755
index 22b5f9c..0000000
--- a/funasr/bin/lm_train.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.lm import LMTask
-
-
-# for LM Training
-def parse_args():
-    parser = LMTask.get_parser()
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    # for LM Training
-    LMTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
-    args = parse_args()
-
-    # setup local gpu_id
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-    assert args.num_worker_count == 1
-
-    # re-compute batch size: when dataset type is small
-    if args.dataset_type == "small" and args.ngpu != 0:
-        if args.batch_size is not None:
-            args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None:
-            args.batch_bins = args.batch_bins * args.ngpu
-
-    main(args=args)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
deleted file mode 100644
index 9efeb5b..0000000
--- a/funasr/bin/punc_infer.py
+++ /dev/null
@@ -1,282 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-import os
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.datasets.preprocessor import split_to_mini_sentence
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-
-
-class Text2Punc:
-
-    def __init__(
-            self,
-            train_config: Optional[str],
-            model_file: Optional[str],
-            device: str = "cpu",
-            dtype: str = "float32",
-    ):
-        #  Build Model
-        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
-        self.device = device
-        # Wrape model to make model.nll() data-parallel
-        self.wrapped_model = ForwardAdaptor(model, "inference")
-        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-        # logging.info(f"Model:\n{model}")
-        self.punc_list = train_args.punc_list
-        self.period = 0
-        for i in range(len(self.punc_list)):
-            if self.punc_list[i] == ",":
-                self.punc_list[i] = "锛�"
-            elif self.punc_list[i] == "?":
-                self.punc_list[i] = "锛�"
-            elif self.punc_list[i] == "銆�":
-                self.period = i
-        self.seg_dict_file = None
-        self.seg_jieba = False
-        if "seg_jieba" in train_args:
-            self.seg_jieba = train_args.seg_jieba
-            self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
-        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
-            train=False,
-            token_type=train_args.token_type,
-            token_list=train_args.token_list,
-            bpemodel=train_args.bpemodel,
-            text_cleaner=train_args.cleaner,
-            g2p_type=train_args.g2p,
-            text_name="text",
-            non_linguistic_symbols=train_args.non_linguistic_symbols,
-            seg_jieba=self.seg_jieba,
-            seg_dict_file=self.seg_dict_file
-        )
-
-    @torch.no_grad()
-    def __call__(self, text: Union[list, str], split_size=20):
-        data = {"text": text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
-        mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
-        assert len(mini_sentences) == len(mini_sentences_id)
-        cache_sent = []
-        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
-        new_mini_sentence = ""
-        new_mini_sentence_punc = []
-        cache_pop_trigger_limit = 200
-        for mini_sentence_i in range(len(mini_sentences)):
-            mini_sentence = mini_sentences[mini_sentence_i]
-            mini_sentence_id = mini_sentences_id[mini_sentence_i]
-            mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
-            data = {
-                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
-                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
-            }
-            data = to_device(data, self.device)
-            y, _ = self.wrapped_model(**data)
-            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-            punctuations = indices
-            if indices.size()[0] != 1:
-                punctuations = torch.squeeze(indices)
-            assert punctuations.size()[0] == len(mini_sentence)
-
-            # Search for the last Period/QuestionMark as cache
-            if mini_sentence_i < len(mini_sentences) - 1:
-                sentenceEnd = -1
-                last_comma_index = -1
-                for i in range(len(punctuations) - 2, 1, -1):
-                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
-                        sentenceEnd = i
-                        break
-                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
-                        last_comma_index = i
-
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
-                    # The sentence it too long, cut off at a comma.
-                    sentenceEnd = last_comma_index
-                    punctuations[sentenceEnd] = self.period
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
-
-            # if len(punctuations) == 0:
-            #    continue
-
-            punctuations_np = punctuations.cpu().numpy()
-            new_mini_sentence_punc += [int(x) for x in punctuations_np]
-            words_with_punc = []
-            for i in range(len(mini_sentence)):
-                if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
-                    mini_sentence[i] = mini_sentence[i].capitalize()
-                if i == 0:
-                    if len(mini_sentence[i][0].encode()) == 1:
-                        mini_sentence[i] = " " + mini_sentence[i]
-                if i > 0:
-                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
-                        mini_sentence[i] = " " + mini_sentence[i]
-                words_with_punc.append(mini_sentence[i])
-                if self.punc_list[punctuations[i]] != "_":
-                    punc_res = self.punc_list[punctuations[i]]
-                    if len(mini_sentence[i][0].encode()) == 1:
-                        if punc_res == "锛�":
-                            punc_res = ","
-                        elif punc_res == "銆�":
-                            punc_res = "."
-                        elif punc_res == "锛�":
-                            punc_res = "?"
-                    words_with_punc.append(punc_res)
-            new_mini_sentence += "".join(words_with_punc)
-            # Add Period for the end of the sentence
-            new_mini_sentence_out = new_mini_sentence
-            new_mini_sentence_punc_out = new_mini_sentence_punc
-            if mini_sentence_i == len(mini_sentences) - 1:
-                if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
-                    new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] == ",":
-                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
-                    new_mini_sentence_out = new_mini_sentence + "銆�"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
-                    new_mini_sentence_out = new_mini_sentence + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-        return new_mini_sentence_out, new_mini_sentence_punc_out
-
-
-class Text2PuncVADRealtime:
-
-    def __init__(
-            self,
-            train_config: Optional[str],
-            model_file: Optional[str],
-            device: str = "cpu",
-            dtype: str = "float32",
-    ):
-        #  Build Model
-        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
-        self.device = device
-        # Wrape model to make model.nll() data-parallel
-        self.wrapped_model = ForwardAdaptor(model, "inference")
-        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
-        # logging.info(f"Model:\n{model}")
-        self.punc_list = train_args.punc_list
-        self.period = 0
-        for i in range(len(self.punc_list)):
-            if self.punc_list[i] == ",":
-                self.punc_list[i] = "锛�"
-            elif self.punc_list[i] == "?":
-                self.punc_list[i] = "锛�"
-            elif self.punc_list[i] == "銆�":
-                self.period = i
-        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
-            train=False,
-            token_type=train_args.token_type,
-            token_list=train_args.token_list,
-            bpemodel=train_args.bpemodel,
-            text_cleaner=train_args.cleaner,
-            g2p_type=train_args.g2p,
-            text_name="text",
-            non_linguistic_symbols=train_args.non_linguistic_symbols,
-        )
-
-    @torch.no_grad()
-    def __call__(self, text: Union[list, str], cache: list, split_size=20):
-        if cache is not None and len(cache) > 0:
-            precache = "".join(cache)
-        else:
-            precache = ""
-            cache = []
-        data = {"text": precache + " " + text}
-        result = self.preprocessor(data=data, uid="12938712838719")
-        split_text = self.preprocessor.pop_split_text_data(result)
-        mini_sentences = split_to_mini_sentence(split_text, split_size)
-        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
-        assert len(mini_sentences) == len(mini_sentences_id)
-        cache_sent = []
-        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
-        sentence_punc_list = []
-        sentence_words_list = []
-        cache_pop_trigger_limit = 200
-        skip_num = 0
-        for mini_sentence_i in range(len(mini_sentences)):
-            mini_sentence = mini_sentences[mini_sentence_i]
-            mini_sentence_id = mini_sentences_id[mini_sentence_i]
-            mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
-            data = {
-                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
-                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
-                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
-            }
-            data = to_device(data, self.device)
-            y, _ = self.wrapped_model(**data)
-            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-            punctuations = indices
-            if indices.size()[0] != 1:
-                punctuations = torch.squeeze(indices)
-            assert punctuations.size()[0] == len(mini_sentence)
-
-            # Search for the last Period/QuestionMark as cache
-            if mini_sentence_i < len(mini_sentences) - 1:
-                sentenceEnd = -1
-                last_comma_index = -1
-                for i in range(len(punctuations) - 2, 1, -1):
-                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
-                        sentenceEnd = i
-                        break
-                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
-                        last_comma_index = i
-
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
-                    # The sentence it too long, cut off at a comma.
-                    sentenceEnd = last_comma_index
-                    punctuations[sentenceEnd] = self.period
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
-
-            punctuations_np = punctuations.cpu().numpy()
-            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
-            sentence_words_list += mini_sentence
-
-        assert len(sentence_punc_list) == len(sentence_words_list)
-        words_with_punc = []
-        sentence_punc_list_out = []
-        for i in range(0, len(sentence_words_list)):
-            if i > 0:
-                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
-                    sentence_words_list[i] = " " + sentence_words_list[i]
-            if skip_num < len(cache):
-                skip_num += 1
-            else:
-                words_with_punc.append(sentence_words_list[i])
-            if skip_num >= len(cache):
-                sentence_punc_list_out.append(sentence_punc_list[i])
-                if sentence_punc_list[i] != "_":
-                    words_with_punc.append(sentence_punc_list[i])
-        sentence_out = "".join(words_with_punc)
-
-        sentenceEnd = -1
-        for i in range(len(sentence_punc_list) - 2, 1, -1):
-            if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
-                sentenceEnd = i
-                break
-        cache_out = sentence_words_list[sentenceEnd + 1:]
-        if sentence_out[-1] in self.punc_list:
-            sentence_out = sentence_out[:-1]
-            sentence_punc_list_out[-1] = "_"
-        return sentence_out, sentence_punc_list_out, cache_out
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
deleted file mode 100755
index 5d917f5..0000000
--- a/funasr/bin/punc_inference_launch.py
+++ /dev/null
@@ -1,252 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Union
-
-import torch
-
-from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_punc(
-        batch_size: int,
-        dtype: str,
-        ngpu: int,
-        seed: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        key_file: Optional[str],
-        train_config: Optional[str],
-        model_file: Optional[str],
-        output_dir: Optional[str] = None,
-        param_dict: dict = None,
-        **kwargs,
-):
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-    text2punc = Text2Punc(train_config, model_file, device)
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[List[Any], bytes, str] = None,
-            output_dir_v2: Optional[str] = None,
-            cache: List[Any] = None,
-            param_dict: dict = None,
-    ):
-        results = []
-        split_size = 20
-
-        if raw_inputs != None:
-            line = raw_inputs.strip()
-            key = "demo"
-            if line == "":
-                item = {'key': key, 'value': ""}
-                results.append(item)
-                return results
-            result, _ = text2punc(line)
-            item = {'key': key, 'value': result}
-            results.append(item)
-            return results
-
-        for inference_text, _, _ in data_path_and_name_and_type:
-            with open(inference_text, "r", encoding="utf-8") as fin:
-                for line in fin:
-                    line = line.strip()
-                    segs = line.split("\t")
-                    if len(segs) != 2:
-                        continue
-                    key = segs[0]
-                    if len(segs[1]) == 0:
-                        continue
-                    result, _ = text2punc(segs[1])
-                    item = {'key': key, 'value': result}
-                    results.append(item)
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path != None:
-            output_file_name = "infer.out"
-            Path(output_path).mkdir(parents=True, exist_ok=True)
-            output_file_path = (Path(output_path) / output_file_name).absolute()
-            with open(output_file_path, "w", encoding="utf-8") as fout:
-                for item_i in results:
-                    key_out = item_i["key"]
-                    value_out = item_i["value"]
-                    fout.write(f"{key_out}\t{value_out}\n")
-        return results
-
-    return _forward
-
-
-def inference_punc_vad_realtime(
-        batch_size: int,
-        dtype: str,
-        ngpu: int,
-        seed: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        # cache: list,
-        key_file: Optional[str],
-        train_config: Optional[str],
-        model_file: Optional[str],
-        output_dir: Optional[str] = None,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-    text2punc = Text2PuncVADRealtime(train_config, model_file, device)
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[List[Any], bytes, str] = None,
-            output_dir_v2: Optional[str] = None,
-            cache: List[Any] = None,
-            param_dict: dict = None,
-    ):
-        results = []
-        split_size = 10
-        cache_in = param_dict["cache"]
-        if raw_inputs != None:
-            line = raw_inputs.strip()
-            key = "demo"
-            if line == "":
-                item = {'key': key, 'value': ""}
-                results.append(item)
-                return results
-            result, _, cache = text2punc(line, cache_in)
-            param_dict["cache"] = cache
-            item = {'key': key, 'value': result}
-            results.append(item)
-            return results
-
-        return results
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "punc":
-        return inference_punc(**kwargs)
-    if mode == "punc_VadRealtime":
-        return inference_punc_vad_realtime(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Punctuation inference",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument("--gpuid_list", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument("--njob", type=int, default=1, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
-    group.add_argument("--raw_inputs", type=str, required=False)
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--cache", type=list, required=False)
-    group.add_argument("--param_dict", type=dict, required=False)
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument("--train_config", type=str)
-    group.add_argument("--model_file", type=str)
-    group.add_argument("--mode", type=str, default="punc")
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    kwargs.pop("gpuid_list", None)
-    kwargs.pop("njob", None)
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
deleted file mode 100644
index c3cbee9..0000000
--- a/funasr/bin/punc_train.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-from funasr.tasks.punctuation import PunctuationTask
-
-
-def parse_args():
-    parser = PunctuationTask.get_parser()
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    parser.add_argument(
-        "--punc_list",
-        type=str,
-        default=None,
-        help="Punctuation list",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    """
-    punc training.
-    """
-    PunctuationTask.main(args=args, cmd=cmd)
-
-
-if __name__ == "__main__":
-    args = parse_args()
-
-    # setup local gpu_id
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-
-    if args.dataset_type == "small":
-        if args.batch_size is not None:
-            args.batch_size = args.batch_size * args.ngpu * args.num_worker_count
-        if args.batch_bins is not None:
-            args.batch_bins = args.batch_bins * args.ngpu * args.num_worker_count
-
-    main(args=args)
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
deleted file mode 100755
index 67106cf..0000000
--- a/funasr/bin/sa_asr_train.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.sa_asr import ASRTask
-
-
-# for ASR Training
-def parse_args():
-    parser = ASRTask.get_parser()
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    # for ASR Training
-    ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
-    args = parse_args()
-
-    # setup local gpu_id
-    if args.ngpu > 0:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-    assert args.num_worker_count == 1
-
-    # re-compute batch size: when dataset type is small
-    if args.dataset_type == "small":
-        if args.batch_size is not None and args.ngpu > 0:
-            args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None and args.ngpu > 0:
-            args.batch_bins = args.batch_bins * args.ngpu
-
-    main(args=args)
diff --git a/funasr/bin/ss_infer.py b/funasr/bin/ss_infer.py
deleted file mode 100644
index a3eca11..0000000
--- a/funasr/bin/ss_infer.py
+++ /dev/null
@@ -1,127 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import logging
-from pathlib import Path
-from typing import List
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-
-
-class SpeechSeparator:
-    """SpeechSeparator class
-
-    Examples:
-        >>> import librosa
-        >>> speech_separator = MossFormer("ss_config.yml", "ss.pt")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> separated_wavs = speech_separator(audio)        
-
-    """
-
-    def __init__(
-            self,
-            ss_infer_config: Union[Path, str] = None,
-            ss_model_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-
-        # 1. Build ss model
-        ss_model, ss_infer_args = build_model_from_file(
-            ss_infer_config, ss_model_file, None, device, task_name="ss"
-        )
-
-        logging.info("ss_model: {}".format(ss_model))
-        logging.info("ss_infer_args: {}".format(ss_infer_args))
-
-        ss_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.ss_model = ss_model
-        self.ss_infer_args = ss_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.batch_size = batch_size
-
-    def decode(self, model, args, inputs, nsamples):
-        decode_do_segment = False
-        with torch.no_grad():       
-            out = []
-            window = args.sample_rate * args.decode_window  # decoding window length
-            stride = int(window*0.75)  # decoding stride if segmentation is used
-            b, t = inputs.shape
-            if t > window * args.one_time_decode_length:
-                decode_do_segment = True  # set segment decoding to true for very long sequence
-
-            if t < window:
-                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], window-t))], 1)
-            elif t < window + stride:
-                padding = window + stride - t
-                inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
-            else:
-                if (t - window) % stride != 0:
-                    padding = t - (t-window)//stride * stride
-                    inputs = np.concatenate([inputs, np.zeros((inputs.shape[0], padding))], 1)
-            inputs = torch.from_numpy(np.float32(inputs))
-            inputs = to_device(inputs, device=self.device)
-            b, t = inputs.shape
-            if decode_do_segment:
-                outputs = np.zeros((args.num_spks, t))
-                give_up_length = (window - stride)//2
-                current_idx = 0
-                while current_idx + window <= t:
-                    tmp_input = inputs[:, current_idx:current_idx+window]
-                    tmp_out_list = model(tmp_input,)
-                    for spk in range(args.num_spks):
-                        tmp_out_list[spk] = tmp_out_list[spk][0, :].cpu().numpy()
-                        if current_idx == 0:
-                            outputs[spk, current_idx:current_idx+window-give_up_length] = \
-                                tmp_out_list[spk][:-give_up_length]
-                        else:
-                            outputs[spk, current_idx+give_up_length:current_idx+window-give_up_length] = \
-                                tmp_out_list[spk][give_up_length:-give_up_length]
-                    current_idx += stride
-                for spk in range(args.num_spks):
-                    out.append(outputs[spk, :])
-            else:
-                out_list = model(inputs)
-                for spk in range(args.num_spks):
-                    out.append(out_list[spk][0, :].cpu().numpy())
-
-            max_abs = 0
-            for spk in range(args.num_spks):
-                if max_abs < max(abs(out[spk])):
-                    max_abs = max(abs(out[spk]))
-            for spk in range(args.num_spks):
-                out[spk] = out[spk][:nsamples]
-                out[spk] = out[spk]/max_abs
-
-        return out
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-    ) -> List[torch.Tensor]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            speech list: list of speech data
-
-        """
-
-        out = self.decode(self.ss_model, self.ss_infer_args, speech, speech_lengths)
-
-        return out
-
diff --git a/funasr/bin/ss_inference_launch.py b/funasr/bin/ss_inference_launch.py
deleted file mode 100644
index 0c02419..0000000
--- a/funasr/bin/ss_inference_launch.py
+++ /dev/null
@@ -1,258 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-import librosa
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2triple_str
-from funasr.bin.ss_infer import SpeechSeparator
-
-
-def inference_ss(
-        batch_size: int,
-        ngpu: int,
-        log_level: Union[int, str],
-        ss_infer_config: Optional[str],
-        ss_model_file: Optional[str],
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        num_workers: int = 1,
-        num_spks: int = 2,
-        sample_rate: int = 8000,
-        param_dict: dict = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-        batch_size = 1
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech separator
-    speech_separator_kwargs = dict(
-        ss_infer_config=ss_infer_config,
-        ss_model_file=ss_model_file,
-        device=device,
-        dtype=dtype,
-    )
-    logging.info("speech_separator_kwargs: {}".format(speech_separator_kwargs))
-    speech_separator = SpeechSeparator(**speech_separator_kwargs)
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None
-    ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="ss",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            num_workers=num_workers,
-        )
-
-        # 4 .Start for-loop
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if not os.path.exists(output_path):
-            cmd = 'mkdir -p ' + output_path 
-            os.system(cmd)       
- 
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            # do speech separation
-            logging.info('decoding: {}'.format(keys[0]))
-            ss_results = speech_separator(**batch)
-            
-            for spk in range(num_spks):
-                # sf.write(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
-                try:
-                    librosa.output.write_wav(os.path.join(output_path, keys[0] + '_s' + str(spk+1)+'.wav'), ss_results[spk], sample_rate)
-                except:
-                    print("To write wav by librosa, you should install librosa<=0.8.0")
-                    raise
-        torch.cuda.empty_cache()
-        return ss_results
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "mossformer":
-        return inference_ss(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Speech Separator Decoding",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=1,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="2",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--ss_infer_config",
-        type=str,
-        help="SS infer configuration",
-    )
-    group.add_argument(
-        "--ss_model_file",
-        type=str,
-        help="SS model parameter file",
-    )
-    group.add_argument(
-        "--ss_train_config",
-        type=str,
-        help="SS training configuration",
-    )
-
-    group = parser.add_argument_group("The inference configuration related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-
-    parser.add_argument(
-        '--num-spks', dest='num_spks', type=int, default=2)
-
-    parser.add_argument(
-        '--one-time-decode-length', dest='one_time_decode_length', type=int,
-        default=60, help='the max length (second) for one-time decoding')
-
-    parser.add_argument(
-        '--decode-window', dest='decode_window', type=int,
-        default=1, help='segmental decoding window length (second)')
-
-    parser.add_argument(
-        '--sample-rate', dest='sample_rate', type=int, default='8000')
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="mossformer",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
-
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
deleted file mode 100755
index 19cfc2e..0000000
--- a/funasr/bin/sv_infer.py
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import logging
-from pathlib import Path
-from typing import Any
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils.misc import statistic_model_parameters
-
-
-class Speech2Xvector:
-    """Speech2Xvector class
-
-    Examples:
-        >>> import librosa
-        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2xvector(audio)
-        [(text, token, token_int, hypothesis object), ...]
-
-    """
-
-    def __init__(
-            self,
-            sv_train_config: Union[Path, str] = None,
-            sv_model_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            streaming: bool = False,
-            embedding_node: str = "resnet1_dense",
-    ):
-
-        # TODO: 1. Build SV model
-        sv_model, sv_train_args = build_model_from_file(
-            config_file=sv_train_config,
-            model_file=sv_model_file,
-            cmvn_file=None,
-            device=device,
-            task_name="sv",
-            mode="sv",
-        )
-        logging.info("sv_model: {}".format(sv_model))
-        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
-        logging.info("sv_train_args: {}".format(sv_train_args))
-        sv_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.sv_model = sv_model
-        self.sv_train_args = sv_train_args
-        self.device = device
-        self.dtype = dtype
-        self.embedding_node = embedding_node
-
-    @torch.no_grad()
-    def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        # data: (Nsamples,) -> (1, Nsamples)
-        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        # lengths: (1,)
-        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        batch = {"speech": speech, "speech_lengths": lengths}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        enc, ilens = self.sv_model.encode(**batch)
-
-        # c. Forward Pooling
-        pooling = self.sv_model.pooling_layer(enc)
-
-        # d. Forward Decoder
-        outputs, embeddings = self.sv_model.decoder(pooling)
-
-        if self.embedding_node not in embeddings:
-            raise ValueError("Required embedding node {} not in {}".format(
-                self.embedding_node, embeddings.keys()))
-
-        return embeddings[self.embedding_node]
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray],
-            ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
-    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-            ref_speech: Reference speech to compare
-        Returns:
-            embedding, ref_embedding, similarity_score
-
-        """
-        self.sv_model.eval()
-        embedding = self.calculate_embedding(speech)
-        ref_emb, score = None, None
-        if ref_speech is not None:
-            ref_emb = self.calculate_embedding(ref_speech)
-            score = torch.cosine_similarity(embedding, ref_emb)
-
-        results = (embedding, ref_emb, score)
-        return results
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
deleted file mode 100755
index 2f9e276..0000000
--- a/funasr/bin/sv_inference_launch.py
+++ /dev/null
@@ -1,309 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from kaldiio import WriteHelper
-
-from funasr.bin.sv_infer import Speech2Xvector
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_sv(
-        output_dir: Optional[str] = None,
-        batch_size: int = 1,
-        dtype: str = "float32",
-        ngpu: int = 1,
-        seed: int = 0,
-        num_workers: int = 0,
-        log_level: Union[int, str] = "INFO",
-        key_file: Optional[str] = None,
-        sv_train_config: Optional[str] = "sv.yaml",
-        sv_model_file: Optional[str] = "sv.pb",
-        model_tag: Optional[str] = None,
-        allow_variable_data_keys: bool = True,
-        streaming: bool = False,
-        embedding_node: str = "resnet1_dense",
-        sv_threshold: float = 0.9465,
-        param_dict: Optional[dict] = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("param_dict: {}".format(param_dict))
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2xvector
-    speech2xvector_kwargs = dict(
-        sv_train_config=sv_train_config,
-        sv_model_file=sv_model_file,
-        device=device,
-        dtype=dtype,
-        streaming=streaming,
-        embedding_node=embedding_node
-    )
-    logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-    speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
-    speech2xvector.sv_model.eval()
-
-    def _forward(
-            data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            param_dict: Optional[dict] = None,
-    ):
-        logging.info("param_dict: {}".format(param_dict))
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
-        # 3. Build data-iterator
-        loader = build_streaming_iterator(
-            task_name="sv",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            use_collate_fn=False,
-        )
-
-        # 7 .Start for-loop
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        embd_writer, ref_embd_writer, score_writer = None, None, None
-        if output_path is not None:
-            os.makedirs(output_path, exist_ok=True)
-            embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
-        sv_result_list = []
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
-            embedding, ref_embedding, score = speech2xvector(**batch)
-            # Only supporting batch_size==1
-            key = keys[0]
-            normalized_score = 0.0
-            if score is not None:
-                score = score.item()
-                normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
-                item = {"key": key, "value": normalized_score}
-            else:
-                item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
-            sv_result_list.append(item)
-            if output_path is not None:
-                embd_writer(key, embedding[0].cpu().numpy())
-                if ref_embedding is not None:
-                    if ref_embd_writer is None:
-                        ref_embd_writer = WriteHelper(
-                            "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
-                        )
-                        score_writer = open(os.path.join(output_path, "score.txt"), "w")
-                    ref_embd_writer(key, ref_embedding[0].cpu().numpy())
-                    score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
-        if output_path is not None:
-            embd_writer.close()
-            if ref_embd_writer is not None:
-                ref_embd_writer.close()
-                score_writer.close()
-
-        return sv_result_list
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "sv":
-        return inference_sv(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Speaker Verification",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=False)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=False,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--sv_train_config",
-        type=str,
-        help="ASR training configuration",
-    )
-    group.add_argument(
-        "--sv_model_file",
-        type=str,
-        help="ASR model parameter file",
-    )
-    group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--model_tag",
-        type=str,
-        help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
-    )
-
-    group = parser.add_argument_group("The inference configuration related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    group.add_argument(
-        "--sv_threshold",
-        type=float,
-        default=0.9465,
-        help="The threshold for verification"
-    )
-    parser.add_argument(
-        "--embedding_node",
-        type=str,
-        default="resnet1_dense",
-        help="The network node to extract embedding"
-    )
-
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="sv",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
deleted file mode 100644
index cfe534f..0000000
--- a/funasr/bin/tp_infer.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import logging
-from pathlib import Path
-from typing import Union
-
-import numpy as np
-import torch
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-
-
-class Speech2Timestamp:
-    def __init__(
-            self,
-            timestamp_infer_config: Union[Path, str] = None,
-            timestamp_model_file: Union[Path, str] = None,
-            timestamp_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        # 1. Build ASR model
-        tp_model, tp_train_args = build_model_from_file(
-            timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
-        )
-        if 'cuda' in device:
-            tp_model = tp_model.cuda()  # force model to cuda
-
-        frontend = None
-        if tp_train_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
-
-        logging.info("tp_model: {}".format(tp_model))
-        logging.info("tp_train_args: {}".format(tp_train_args))
-        tp_model.to(dtype=getattr(torch, dtype)).eval()
-
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
-        self.tp_model = tp_model
-        self.tp_train_args = tp_train_args
-
-        token_list = self.tp_model.token_list
-        self.converter = TokenIDConverter(token_list=token_list)
-
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.encoder_downsampling_factor = 1
-        if tp_train_args.encoder_conf["input_layer"] == "conv2d":
-            self.encoder_downsampling_factor = 4
-
-    @torch.no_grad()
-    def __call__(
-            self,
-            speech: Union[torch.Tensor, np.ndarray],
-            speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            text_lengths: Union[torch.Tensor, np.ndarray] = None
-    ):
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.tp_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-
-        # lfr_factor = max(1, (feats.size()[-1]//80)-1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        enc, enc_len = self.tp_model.encode(**batch)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-
-        # c. Forward Predictor
-        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
-                                                                           text_lengths.to(self.device) + 1)
-        return us_alphas, us_peaks
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
deleted file mode 100644
index 6c10254..0000000
--- a/funasr/bin/tp_inference_launch.py
+++ /dev/null
@@ -1,287 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import argparse
-import logging
-import os
-import sys
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def inference_tp(
-        batch_size: int,
-        ngpu: int,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        timestamp_infer_config: Optional[str],
-        timestamp_model_file: Optional[str],
-        timestamp_cmvn_file: Optional[str] = None,
-        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-        key_file: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        num_workers: int = 1,
-        split_with_space: bool = True,
-        seg_dict_file: Optional[str] = None,
-        **kwargs,
-):
-    ncpu = kwargs.get("ncpu", 1)
-    torch.set_num_threads(ncpu)
-
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-    if ngpu > 1:
-        raise NotImplementedError("only single GPU decoding is supported")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2vadsegment
-    speechtext2timestamp_kwargs = dict(
-        timestamp_infer_config=timestamp_infer_config,
-        timestamp_model_file=timestamp_model_file,
-        timestamp_cmvn_file=timestamp_cmvn_file,
-        device=device,
-        dtype=dtype,
-    )
-    logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
-    speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
-
-    preprocessor = LMPreprocessor(
-        train=False,
-        token_type=speechtext2timestamp.tp_train_args.token_type,
-        token_list=speechtext2timestamp.tp_train_args.token_list,
-        bpemodel=None,
-        text_cleaner=None,
-        g2p_type=None,
-        text_name="text",
-        non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
-        split_with_space=split_with_space,
-        seg_dict_file=seg_dict_file,
-    )
-
-    if output_dir is not None:
-        writer = DatadirWriter(output_dir)
-        tp_writer = writer[f"timestamp_prediction"]
-        # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-    else:
-        tp_writer = None
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None,
-            **kwargs
-    ):
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        writer = None
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-            tp_writer = writer[f"timestamp_prediction"]
-        else:
-            tp_writer = None
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
-        loader = build_streaming_iterator(
-            task_name="asr",
-            preprocess_args=speechtext2timestamp.tp_train_args,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            preprocess_fn=preprocessor,
-        )
-
-        tp_result_list = []
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            logging.info("timestamp predicting, utt_id: {}".format(keys))
-            _batch = {'speech': batch['speech'],
-                      'speech_lengths': batch['speech_lengths'],
-                      'text_lengths': batch['text_lengths']}
-            us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
-
-            for batch_id in range(_bs):
-                key = keys[batch_id]
-                token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
-                ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token,
-                                                              force_time_shift=-3.0)
-                logging.warning(ts_str)
-                item = {'key': key, 'value': ts_str, 'timestamp': ts_list}
-                if tp_writer is not None:
-                    tp_writer["tp_sync"][key + '#'] = ts_str
-                    tp_writer["tp_time"][key + '#'] = str(ts_list)
-                tp_result_list.append(item)
-        return tp_result_list
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "tp_norm":
-        return inference_tp(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="Timestamp Prediction Inference",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=False)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--timestamp_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--timestamp_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--timestamp_cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-
-    group = parser.add_argument_group("The inference configuration related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="tp_norm",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/cli/train_cli.py b/funasr/bin/train.py
similarity index 96%
rename from funasr/cli/train_cli.py
rename to funasr/bin/train.py
index a22d5d4..4187476 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/bin/train.py
@@ -19,18 +19,13 @@
 # from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.tokenizer.funtoken import build_tokenizer
 from funasr.datasets.dataset_jsonl import AudioDataset
-from funasr.cli.trainer import Trainer
+from funasr.utils.trainer import Trainer
 # from funasr.utils.load_fr_py import load_class_from_path
 from funasr.utils.dynamic_import import dynamic_import
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from funasr.utils.download_from_hub import download_model
-
-def preprocess_config(cfg: DictConfig):
-	for key, value in cfg.items():
-		if value == 'None':
-			cfg[key] = None
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
deleted file mode 100644
index 5763873..0000000
--- a/funasr/bin/vad_infer.py
+++ /dev/null
@@ -1,180 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import logging
-import math
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-
-from funasr.build_utils.build_model_from_file import build_model_from_file
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.torch_utils.device_funcs import to_device
-
-
-class Speech2VadSegment:
-    """Speech2VadSegment class
-
-    Examples:
-        >>> import librosa
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-
-    """
-
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-
-        # 1. Build vad model
-        vad_model, vad_infer_args = build_model_from_file(
-            vad_infer_config, vad_model_file, None, device, task_name="vad"
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
-        logging.info("vad_model: {}".format(vad_model))
-        logging.info("vad_infer_args: {}".format(vad_infer_args))
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict()
-    ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len.max(), 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final = True
-            else:
-                is_final = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final": is_final,
-                "in_cache": in_cache
-            }
-            # a. To device
-            # batch = to_device(batch, device=self.device)
-            segments_part, in_cache = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-        return fbanks, segments
-
-
-class Speech2VadSegmentOnline(Speech2VadSegment):
-    """Speech2VadSegmentOnline class
-
-    Examples:
-        >>> import librosa
-        >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
-        >>> audio, rate = librosa.load("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-
-    """
-
-    def __init__(self, **kwargs):
-        super(Speech2VadSegmentOnline, self).__init__(**kwargs)
-        vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
-        self.frontend = None
-        if self.vad_infer_args.frontend is not None:
-            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
-    ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-
-        """
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-        batch_size = speech.shape[0]
-        segments = [[]] * batch_size
-        if self.frontend is not None:
-            reset = in_cache == dict()
-            feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final, reset)
-            fbanks, _ = self.frontend.get_fbank()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-        if feats.shape[0]:
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            waveforms = self.frontend.get_waveforms()
-            if max_end_sil == 800 and self.vad_infer_args.vad_post_conf["max_end_silence_time"] != 800:
-                max_end_sil = self.vad_infer_args.vad_post_conf["max_end_silence_time"]
-
-            batch = {
-                "feats": feats,
-                "waveform": waveforms,
-                "in_cache": in_cache,
-                "is_final": is_final,
-                "max_end_sil": max_end_sil
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments, in_cache = self.vad_model.forward_online(**batch)
-            # in_cache.update(batch['in_cache'])
-            # in_cache = {key: value for key, value in batch['in_cache'].items()}
-        return fbanks, segments, in_cache
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
deleted file mode 100644
index a031a5a..0000000
--- a/funasr/bin/vad_inference_launch.py
+++ /dev/null
@@ -1,379 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-
-import torch
-
-torch.set_num_threads(1)
-
-import argparse
-import logging
-import os
-import sys
-import json
-from typing import Optional
-from typing import Union
-
-import numpy as np
-import torch
-from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
-
-
-def inference_vad(
-        batch_size: int,
-        ngpu: int,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        vad_infer_config: Optional[str],
-        vad_model_file: Optional[str],
-        vad_cmvn_file: Optional[str] = None,
-        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-        key_file: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        num_workers: int = 1,
-        **kwargs,
-):
-    if batch_size > 1:
-        raise NotImplementedError("batch decoding is not implemented")
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-        batch_size = 1
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2vadsegment
-    speech2vadsegment_kwargs = dict(
-        vad_infer_config=vad_infer_config,
-        vad_model_file=vad_model_file,
-        vad_cmvn_file=vad_cmvn_file,
-        device=device,
-        dtype=dtype,
-    )
-    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
-    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None
-    ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="vad",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-            ibest_writer = writer[f"1best_recog"]
-        else:
-            writer = None
-            ibest_writer = None
-
-        vad_results = []
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
-            # do vad segment
-            _, results = speech2vadsegment(**batch)
-            for i, _ in enumerate(keys):
-                if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                    results[i] = json.dumps(results[i])
-                item = {'key': keys[i], 'value': results[i]}
-                vad_results.append(item)
-                if writer is not None:
-                    ibest_writer["text"][keys[i]] = "{}".format(results[i])
-        torch.cuda.empty_cache()
-        return vad_results
-
-    return _forward
-
-
-def inference_vad_online(
-        batch_size: int,
-        ngpu: int,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        vad_infer_config: Optional[str],
-        vad_model_file: Optional[str],
-        vad_cmvn_file: Optional[str] = None,
-        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-        key_file: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        num_workers: int = 1,
-        **kwargs,
-):
-
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-
-    if ngpu >= 1 and torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-        batch_size = 1
-
-    # 1. Set random-seed
-    set_all_random_seed(seed)
-
-    # 2. Build speech2vadsegment
-    speech2vadsegment_kwargs = dict(
-        vad_infer_config=vad_infer_config,
-        vad_model_file=vad_model_file,
-        vad_cmvn_file=vad_cmvn_file,
-        device=device,
-        dtype=dtype,
-    )
-    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
-    speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
-
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None,
-    ):
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = build_streaming_iterator(
-            task_name="vad",
-            preprocess_args=None,
-            data_path_and_name_and_type=data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-        )
-
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-            ibest_writer = writer[f"1best_recog"]
-        else:
-            writer = None
-            ibest_writer = None
-
-        vad_results = []
-        if param_dict is None:
-            param_dict = dict()
-            param_dict['in_cache'] = dict()
-            param_dict['is_final'] = True
-        batch_in_cache = param_dict.get('in_cache', dict())
-        is_final = param_dict.get('is_final', False)
-        max_end_sil = param_dict.get('max_end_sil', 800)
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            batch['in_cache'] = batch_in_cache
-            batch['is_final'] = is_final
-            batch['max_end_sil'] = max_end_sil
-
-            # do vad segment
-            _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
-            # param_dict['in_cache'] = batch['in_cache']
-            if results:
-                for i, _ in enumerate(keys):
-                    if results[i]:
-                        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                            results[i] = json.dumps(results[i])
-                        item = {'key': keys[i], 'value': results[i]}
-                        vad_results.append(item)
-                        if writer is not None:
-                            ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
-        return vad_results
-
-    return _forward
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "offline":
-        return inference_vad(**kwargs)
-    elif mode == "online":
-        return inference_vad_online(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="VAD Decoding",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--vad_cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--vad_train_config",
-        type=str,
-        help="VAD training configuration",
-    )
-
-    group = parser.add_argument_group("The inference configuration related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="vad",
-        help="The decoding mode",
-    )
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    kwargs.pop("config", None)
-
-    # set logging messages
-    logging.basicConfig(
-        level=args.log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
-    logging.info("Decoding args: {}".format(kwargs))
-
-    # gpu setting
-    if args.ngpu > 0:
-        jobid = int(args.output_dir.split(".")[-1])
-        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
-    inference_pipeline = inference_launch(**kwargs)
-    return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/cli/__init__.py b/funasr/cli/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/cli/__init__.py
+++ /dev/null
diff --git a/funasr/cli/models/__init__.py b/funasr/cli/models/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/cli/models/__init__.py
+++ /dev/null
diff --git a/funasr/cli/models/paraformer.py b/funasr/cli/models/paraformer.py
deleted file mode 100644
index 7ca80f5..0000000
--- a/funasr/cli/models/paraformer.py
+++ /dev/null
@@ -1,655 +0,0 @@
-import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask, pad_list
-from funasr.modules.nets_utils import th_accuracy
-from funasr.torch_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-
-from funasr.cli.model_class_factory import *
-
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-
-
-class Paraformer(nn.Module):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-	https://arxiv.org/abs/2206.08317
-	"""
-	
-	def __init__(
-		self,
-		# token_list: Union[Tuple[str, ...], List[str]],
-		frontend: Optional[str] = None,
-		frontend_conf: Optional[Dict] = None,
-		specaug: Optional[str] = None,
-		specaug_conf: Optional[Dict] = None,
-		normalize: str = None,
-		normalize_conf: Optional[Dict] = None,
-		encoder: str = None,
-		encoder_conf: Optional[Dict] = None,
-		decoder: str = None,
-		decoder_conf: Optional[Dict] = None,
-		ctc: str = None,
-		ctc_conf: Optional[Dict] = None,
-		predictor: str = None,
-		predictor_conf: Optional[Dict] = None,
-		ctc_weight: float = 0.5,
-		interctc_weight: float = 0.0,
-		input_size: int = 80,
-		vocab_size: int = -1,
-		ignore_id: int = -1,
-		blank_id: int = 0,
-		sos: int = 1,
-		eos: int = 2,
-		lsm_weight: float = 0.0,
-		length_normalized_loss: bool = False,
-		# report_cer: bool = True,
-		# report_wer: bool = True,
-		# sym_space: str = "<space>",
-		# sym_blank: str = "<blank>",
-		# extract_feats_in_collect_stats: bool = True,
-		# predictor=None,
-		predictor_weight: float = 0.0,
-		predictor_bias: int = 0,
-		sampling_ratio: float = 0.2,
-		share_embedding: bool = False,
-		# preencoder: Optional[AbsPreEncoder] = None,
-		# postencoder: Optional[AbsPostEncoder] = None,
-		use_1st_decoder_loss: bool = False,
-		**kwargs,
-	):
-		assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-		assert 0.0 <= interctc_weight < 1.0, interctc_weight
-		
-		super().__init__()
-		
-		# import pdb;
-		# pdb.set_trace()
-		
-		if frontend is not None:
-			frontend_class = frontend_choices.get_class(frontend)
-			frontend = frontend_class(**frontend_conf)
-		if specaug is not None:
-			specaug_class = specaug_choices.get_class(specaug)
-			specaug = specaug_class(**specaug_conf)
-		if normalize is not None:
-			normalize_class = normalize_choices.get_class(normalize)
-			normalize = normalize_class(**normalize_conf)
-		encoder_class = encoder_choices.get_class(encoder)
-		encoder = encoder_class(input_size=input_size, **encoder_conf)
-		encoder_output_size = encoder.output_size()
-		if decoder is not None:
-			decoder_class = decoder_choices.get_class(decoder)
-			decoder = decoder_class(
-				vocab_size=vocab_size,
-				encoder_output_size=encoder_output_size,
-				**decoder_conf,
-			)
-		if ctc_weight > 0.0:
-			
-			if ctc_conf is None:
-				ctc_conf = {}
-				
-			ctc = CTC(
-				odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
-			)
-		if predictor is not None:
-			predictor_class = predictor_choices.get_class(predictor)
-			predictor = predictor_class(**predictor_conf)
-		
-		# note that eos is the same as sos (equivalent ID)
-		self.blank_id = blank_id
-		self.sos = sos if sos is not None else vocab_size - 1
-		self.eos = eos if eos is not None else vocab_size - 1
-		self.vocab_size = vocab_size
-		self.ignore_id = ignore_id
-		self.ctc_weight = ctc_weight
-		self.interctc_weight = interctc_weight
-		# self.token_list = token_list.copy()
-		#
-		self.frontend = frontend
-		self.specaug = specaug
-		self.normalize = normalize
-		# self.preencoder = preencoder
-		# self.postencoder = postencoder
-		self.encoder = encoder
-		#
-		# if not hasattr(self.encoder, "interctc_use_conditioning"):
-		# 	self.encoder.interctc_use_conditioning = False
-		# if self.encoder.interctc_use_conditioning:
-		# 	self.encoder.conditioning_layer = torch.nn.Linear(
-		# 		vocab_size, self.encoder.output_size()
-		# 	)
-		#
-		# self.error_calculator = None
-		#
-		if ctc_weight == 1.0:
-			self.decoder = None
-		else:
-			self.decoder = decoder
-
-		self.criterion_att = LabelSmoothingLoss(
-			size=vocab_size,
-			padding_idx=ignore_id,
-			smoothing=lsm_weight,
-			normalize_length=length_normalized_loss,
-		)
-		#
-		# if report_cer or report_wer:
-		# 	self.error_calculator = ErrorCalculator(
-		# 		token_list, sym_space, sym_blank, report_cer, report_wer
-		# 	)
-		#
-		if ctc_weight == 0.0:
-			self.ctc = None
-		else:
-			self.ctc = ctc
-		#
-		# self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-		self.predictor = predictor
-		self.predictor_weight = predictor_weight
-		self.predictor_bias = predictor_bias
-		self.sampling_ratio = sampling_ratio
-		self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
-		# self.step_cur = 0
-		#
-		self.share_embedding = share_embedding
-		if self.share_embedding:
-			self.decoder.embed = None
-
-		self.use_1st_decoder_loss = use_1st_decoder_loss
-		self.length_normalized_loss = length_normalized_loss
-	
-	def forward(
-		self,
-		speech: torch.Tensor,
-		speech_lengths: torch.Tensor,
-		text: torch.Tensor,
-		text_lengths: torch.Tensor,
-		**kwargs,
-	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-		"""Frontend + Encoder + Decoder + Calc loss
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				text: (Batch, Length)
-				text_lengths: (Batch,)
-				decoding_ind: int
-		"""
-		decoding_ind = kwargs.get("kwargs", None)
-		# import pdb;
-		# pdb.set_trace()
-		if len(text_lengths.size()) > 1:
-			text_lengths = text_lengths[:, 0]
-		if len(speech_lengths.size()) > 1:
-			speech_lengths = speech_lengths[:, 0]
-
-		batch_size = speech.shape[0]
-		
-		# # for data-parallel
-		# text = text[:, : text_lengths.max()]
-		# speech = speech[:, :speech_lengths.max()]
-		
-		# 1. Encoder
-		if hasattr(self.encoder, "overlap_chunk_cls"):
-			ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
-			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
-		else:
-			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-		intermediate_outs = None
-		if isinstance(encoder_out, tuple):
-			intermediate_outs = encoder_out[1]
-			encoder_out = encoder_out[0]
-		
-		loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
-		loss_ctc, cer_ctc = None, None
-		loss_pre = None
-		stats = dict()
-		
-		# 1. CTC branch
-		if self.ctc_weight != 0.0:
-			loss_ctc, cer_ctc = self._calc_ctc_loss(
-				encoder_out, encoder_out_lens, text, text_lengths
-			)
-			
-			# Collect CTC branch stats
-			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-			stats["cer_ctc"] = cer_ctc
-		
-		# Intermediate CTC (optional)
-		loss_interctc = 0.0
-		if self.interctc_weight != 0.0 and intermediate_outs is not None:
-			for layer_idx, intermediate_out in intermediate_outs:
-				# we assume intermediate_out has the same length & padding
-				# as those of encoder_out
-				loss_ic, cer_ic = self._calc_ctc_loss(
-					intermediate_out, encoder_out_lens, text, text_lengths
-				)
-				loss_interctc = loss_interctc + loss_ic
-				
-				# Collect Intermedaite CTC stats
-				stats["loss_interctc_layer{}".format(layer_idx)] = (
-					loss_ic.detach() if loss_ic is not None else None
-				)
-				stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-			
-			loss_interctc = loss_interctc / len(intermediate_outs)
-			
-			# calculate whole encoder loss
-			loss_ctc = (
-				           1 - self.interctc_weight
-			           ) * loss_ctc + self.interctc_weight * loss_interctc
-		
-		# 2b. Attention decoder branch
-		if self.ctc_weight != 1.0:
-			loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
-				encoder_out, encoder_out_lens, text, text_lengths
-			)
-		
-		# 3. CTC-Att loss definition
-		if self.ctc_weight == 0.0:
-			loss = loss_att + loss_pre * self.predictor_weight
-		elif self.ctc_weight == 1.0:
-			loss = loss_ctc
-		else:
-			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-		
-		if self.use_1st_decoder_loss and pre_loss_att is not None:
-			loss = loss + (1 - self.ctc_weight) * pre_loss_att
-		
-		# Collect Attn branch stats
-		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
-		stats["acc"] = acc_att
-		stats["cer"] = cer_att
-		stats["wer"] = wer_att
-		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-		
-		stats["loss"] = torch.clone(loss.detach())
-		
-		# force_gatherable: to-device and to-tensor if scalar for DataParallel
-		if self.length_normalized_loss:
-			batch_size = (text_lengths + self.predictor_bias).sum()
-		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-		return loss, stats, weight
-	
-	def collect_feats(
-		self,
-		speech: torch.Tensor,
-		speech_lengths: torch.Tensor,
-		text: torch.Tensor,
-		text_lengths: torch.Tensor,
-	) -> Dict[str, torch.Tensor]:
-		if self.extract_feats_in_collect_stats:
-			feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-		else:
-			# Generate dummy stats if extract_feats_in_collect_stats is False
-			logging.warning(
-				"Generating dummy stats for feats and feats_lengths, "
-				"because encoder_conf.extract_feats_in_collect_stats is "
-				f"{self.extract_feats_in_collect_stats}"
-			)
-			feats, feats_lengths = speech, speech_lengths
-		return {"feats": feats, "feats_lengths": feats_lengths}
-	
-	def encode(
-		self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-	) -> Tuple[torch.Tensor, torch.Tensor]:
-		"""Frontend + Encoder. Note that this method is used by asr_inference.py
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				ind: int
-		"""
-		with autocast(False):
-			# # 1. Extract feats
-			# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-			
-			# 2. Data augmentation
-			if self.specaug is not None and self.training:
-				feats, feats_lengths = self.specaug(speech, speech_lengths)
-			
-			# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-			if self.normalize is not None:
-				feats, feats_lengths = self.normalize(feats, feats_lengths)
-		
-		# # Pre-encoder, e.g. used for raw input data
-		# if self.preencoder is not None:
-		# 	feats, feats_lengths = self.preencoder(feats, feats_lengths)
-		
-		# 4. Forward encoder
-		# feats: (Batch, Length, Dim)
-		# -> encoder_out: (Batch, Length2, Dim2)
-		if self.encoder.interctc_use_conditioning:
-			if hasattr(self.encoder, "overlap_chunk_cls"):
-				encoder_out, encoder_out_lens, _ = self.encoder(
-					feats, feats_lengths, ctc=self.ctc, ind=ind
-				)
-				encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-				                                                                            encoder_out_lens,
-				                                                                            chunk_outs=None)
-			else:
-				encoder_out, encoder_out_lens, _ = self.encoder(
-					feats, feats_lengths, ctc=self.ctc
-				)
-		else:
-			if hasattr(self.encoder, "overlap_chunk_cls"):
-				encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
-				encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-				                                                                            encoder_out_lens,
-				                                                                            chunk_outs=None)
-			else:
-				encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
-		intermediate_outs = None
-		if isinstance(encoder_out, tuple):
-			intermediate_outs = encoder_out[1]
-			encoder_out = encoder_out[0]
-		
-		# # Post-encoder, e.g. NLU
-		# if self.postencoder is not None:
-		# 	encoder_out, encoder_out_lens = self.postencoder(
-		# 		encoder_out, encoder_out_lens
-		# 	)
-		
-		assert encoder_out.size(0) == speech.size(0), (
-			encoder_out.size(),
-			speech.size(0),
-		)
-		assert encoder_out.size(1) <= encoder_out_lens.max(), (
-			encoder_out.size(),
-			encoder_out_lens.max(),
-		)
-		
-		if intermediate_outs is not None:
-			return (encoder_out, intermediate_outs), encoder_out_lens
-		
-		return encoder_out, encoder_out_lens
-	
-	def calc_predictor(self, encoder_out, encoder_out_lens):
-		
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
-		                                                                               ignore_id=self.ignore_id)
-		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-	
-	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-		
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
-		)
-		decoder_out = decoder_outs[0]
-		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out, ys_pad_lens
-	
-	def _extract_feats(
-		self, speech: torch.Tensor, speech_lengths: torch.Tensor
-	) -> Tuple[torch.Tensor, torch.Tensor]:
-		assert speech_lengths.dim() == 1, speech_lengths.shape
-		
-		# for data-parallel
-		speech = speech[:, : speech_lengths.max()]
-		if self.frontend is not None:
-			# Frontend
-			#  e.g. STFT and Feature extract
-			#       data_loader may send time-domain signal in this case
-			# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
-			feats, feats_lengths = self.frontend(speech, speech_lengths)
-		else:
-			# No frontend and no feature extract
-			feats, feats_lengths = speech, speech_lengths
-		return feats, feats_lengths
-	
-	def nll(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	) -> torch.Tensor:
-		"""Compute negative log likelihood(nll) from transformer-decoder
-		Normally, this function is called in batchify_nll.
-		Args:
-				encoder_out: (Batch, Length, Dim)
-				encoder_out_lens: (Batch,)
-				ys_pad: (Batch, Length)
-				ys_pad_lens: (Batch,)
-		"""
-		ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-		ys_in_lens = ys_pad_lens + 1
-		
-		# 1. Forward decoder
-		decoder_out, _ = self.decoder(
-			encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-		)  # [batch, seqlen, dim]
-		batch_size = decoder_out.size(0)
-		decoder_num_class = decoder_out.size(2)
-		# nll: negative log-likelihood
-		nll = torch.nn.functional.cross_entropy(
-			decoder_out.view(-1, decoder_num_class),
-			ys_out_pad.view(-1),
-			ignore_index=self.ignore_id,
-			reduction="none",
-		)
-		nll = nll.view(batch_size, -1)
-		nll = nll.sum(dim=1)
-		assert nll.size(0) == batch_size
-		return nll
-	
-	def batchify_nll(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-		batch_size: int = 100,
-	):
-		"""Compute negative log likelihood(nll) from transformer-decoder
-		To avoid OOM, this fuction seperate the input into batches.
-		Then call nll for each batch and combine and return results.
-		Args:
-				encoder_out: (Batch, Length, Dim)
-				encoder_out_lens: (Batch,)
-				ys_pad: (Batch, Length)
-				ys_pad_lens: (Batch,)
-				batch_size: int, samples each batch contain when computing nll,
-										you may change this to avoid OOM or increase
-										GPU memory usage
-		"""
-		total_num = encoder_out.size(0)
-		if total_num <= batch_size:
-			nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-		else:
-			nll = []
-			start_idx = 0
-			while True:
-				end_idx = min(start_idx + batch_size, total_num)
-				batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
-				batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
-				batch_ys_pad = ys_pad[start_idx:end_idx, :]
-				batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
-				batch_nll = self.nll(
-					batch_encoder_out,
-					batch_encoder_out_lens,
-					batch_ys_pad,
-					batch_ys_pad_lens,
-				)
-				nll.append(batch_nll)
-				start_idx = end_idx
-				if start_idx == total_num:
-					break
-			nll = torch.cat(nll)
-		assert nll.size(0) == total_num
-		return nll
-	
-	def _calc_att_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		if self.predictor_bias == 1:
-			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-			ys_pad_lens = ys_pad_lens + self.predictor_bias
-		pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-		                                                                          ignore_id=self.ignore_id)
-		
-		# 0. sampler
-		decoder_out_1st = None
-		pre_loss_att = None
-		if self.sampling_ratio > 0.0:
-
-
-			if self.use_1st_decoder_loss:
-				sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-				                                                                       pre_acoustic_embeds)
-			else:
-				sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-				                                               pre_acoustic_embeds)
-		else:
-			if self.step_cur < 2:
-				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-			sematic_embeds = pre_acoustic_embeds
-		
-		# 1. Forward decoder
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
-		)
-		decoder_out, _ = decoder_outs[0], decoder_outs[1]
-		
-		if decoder_out_1st is None:
-			decoder_out_1st = decoder_out
-		# 2. Compute attention loss
-		loss_att = self.criterion_att(decoder_out, ys_pad)
-		acc_att = th_accuracy(
-			decoder_out_1st.view(-1, self.vocab_size),
-			ys_pad,
-			ignore_label=self.ignore_id,
-		)
-		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-		
-		# Compute cer/wer using attention-decoder
-		if self.training or self.error_calculator is None:
-			cer_att, wer_att = None, None
-		else:
-			ys_hat = decoder_out_1st.argmax(dim=-1)
-			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-		
-		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-	
-	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
-		
-		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
-		if self.share_embedding:
-			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
-		else:
-			ys_pad_embed = self.decoder.embed(ys_pad_masked)
-		with torch.no_grad():
-			decoder_outs = self.decoder(
-				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
-			)
-			decoder_out, _ = decoder_outs[0], decoder_outs[1]
-			pred_tokens = decoder_out.argmax(-1)
-			nonpad_positions = ys_pad.ne(self.ignore_id)
-			seq_lens = (nonpad_positions).sum(1)
-			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
-			input_mask = torch.ones_like(nonpad_positions)
-			bsz, seq_len = ys_pad.size()
-			for li in range(bsz):
-				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
-				if target_num > 0:
-					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
-			input_mask = input_mask.eq(1)
-			input_mask = input_mask.masked_fill(~nonpad_positions, False)
-			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-		
-		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
-			input_mask_expand_dim, 0)
-		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-	
-	def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
-		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
-		if self.share_embedding:
-			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
-		else:
-			ys_pad_embed = self.decoder.embed(ys_pad_masked)
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
-		)
-		pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
-		decoder_out, _ = decoder_outs[0], decoder_outs[1]
-		pred_tokens = decoder_out.argmax(-1)
-		nonpad_positions = ys_pad.ne(self.ignore_id)
-		seq_lens = (nonpad_positions).sum(1)
-		same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
-		input_mask = torch.ones_like(nonpad_positions)
-		bsz, seq_len = ys_pad.size()
-		for li in range(bsz):
-			target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
-			if target_num > 0:
-				input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
-		input_mask = input_mask.eq(1)
-		input_mask = input_mask.masked_fill(~nonpad_positions, False)
-		input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-		
-		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
-			input_mask_expand_dim, 0)
-		
-		return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
-	
-	def _calc_ctc_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	):
-		# Calc CTC loss
-		loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-		
-		# Calc CER using CTC
-		cer_ctc = None
-		if not self.training and self.error_calculator is not None:
-			ys_hat = self.ctc.argmax(encoder_out).data
-			cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-		return loss_ctc, cer_ctc
diff --git a/funasr/cli/model_class_factory.py b/funasr/models/model_class_factory.py
similarity index 93%
rename from funasr/cli/model_class_factory.py
rename to funasr/models/model_class_factory.py
index b329492..819ca21 100644
--- a/funasr/cli/model_class_factory.py
+++ b/funasr/models/model_class_factory.py
@@ -123,27 +123,7 @@
     default=None,
     optional=True,
 )
-# model_choices = ClassChoices(
-#     "model",
-#     classes=dict(
-#         asr=ASRModel,
-#         uniasr=UniASR,
-#         paraformer=Paraformer,
-#         paraformer_online=ParaformerOnline,
-#         paraformer_bert=ParaformerBert,
-#         bicif_paraformer=BiCifParaformer,
-#         contextual_paraformer=ContextualParaformer,
-#         neatcontextual_paraformer=NeatContextualParaformer,
-#         mfcca=MFCCA,
-#         timestamp_prediction=TimestampPredictor,
-#         rnnt=TransducerModel,
-#         rnnt_unified=UnifiedTransducerModel,
-#         bat=BATModel,
-#         sa_asr=SAASRModel,
-#     ),
-#     type_check=None,
-#     default="asr",
-# )
+
 preencoder_choices = ClassChoices(
     name="preencoder",
     classes=dict(
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 75b36a9..50e7cd7 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -37,7 +37,7 @@
 # from funasr.models.predictor.cif import CifPredictorV3
 from funasr.models.paraformer.search import Hypothesis
 
-from funasr.cli.model_class_factory import *
+from funasr.models.model_class_factory import *
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
 	from torch.cuda.amp import autocast
diff --git a/funasr/cli/trainer.py b/funasr/utils/trainer.py
similarity index 100%
rename from funasr/cli/trainer.py
rename to funasr/utils/trainer.py
diff --git a/setup.py b/setup.py
index 197f346..a1e47af 100644
--- a/setup.py
+++ b/setup.py
@@ -131,6 +131,6 @@
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     entry_points={"console_scripts": [
-        "funasr = funasr.bin.inference_cli:main",
+        "funasr = funasr.bin.inference:main_hydra",
     ]},
 )

--
Gitblit v1.9.1