jmwang66
2023-06-20 2ff405b2f4ab899eff9bece232969fbb0c8f0555
funasr/bin/asr_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,109 +7,77 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
#!/usr/bin/env python3
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import yaml
import numpy as np
import torch
import torchaudio
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextSAASR
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextSAASR
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -120,23 +88,23 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -160,7 +128,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -173,20 +141,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7. Start for-loop
@@ -197,14 +163,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -212,19 +178,19 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -233,67 +199,67 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list
    return _forward
def inference_paraformer(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    timestamp_infer_config: Union[Path, str] = None,
    timestamp_model_file: Union[Path, str] = None,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        output_dir: Optional[str] = None,
        timestamp_infer_config: Union[Path, str] = None,
        timestamp_model_file: Union[Path, str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    export_mode = False
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
@@ -301,10 +267,10 @@
    else:
        device = "cpu"
        batch_size = 1
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -326,9 +292,9 @@
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    if timestamp_model_file is not None:
        speechtext2timestamp = Speech2Timestamp(
            timestamp_cmvn_file=cmvn_file,
@@ -337,16 +303,16 @@
        )
    else:
        speechtext2timestamp = None
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
            **kwargs,
    ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
@@ -354,30 +320,28 @@
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        forward_time_total = 0.0
        length_total = 0.0
        finish_count = 0
@@ -390,17 +354,17 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            results = speech2text(**batch)
            if len(results) < 1:
@@ -416,10 +380,10 @@
                                                                                               100 * forward_time / (
                                                                                                       length * lfr_factor))
            logging.info(rtf_cur)
            for batch_id in range(_bs):
                result = [results[batch_id][:-2]]
                key = keys[batch_id]
                for n, result in zip(range(1, nbest + 1), result):
                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
@@ -438,13 +402,13 @@
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]
                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        if use_timestamp and timestamp is not None:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
@@ -465,7 +429,7 @@
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
                                                                                                           forward_time_total,
@@ -475,74 +439,74 @@
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        return asr_result_list
    return _forward
def inference_paraformer_vad_punc(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        vad_infer_config: Optional[str] = None,
        vad_model_file: Optional[str] = None,
        vad_cmvn_file: Optional[str] = None,
        time_stamp_writer: bool = True,
        punc_infer_config: Optional[str] = None,
        punc_model_file: Optional[str] = None,
        outputs_dict: Optional[bool] = True,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
@@ -553,7 +517,7 @@
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -579,12 +543,12 @@
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -592,43 +556,41 @@
                 param_dict: dict = None,
                 **kwargs,
                 ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        finish_count = 0
        file_count = 1
        lfr_factor = 6
@@ -639,7 +601,7 @@
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
@@ -648,15 +610,16 @@
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad-beg_vad)
            print("time cost vad: ", end_vad - beg_vad)
            _, vadsegments = vad_results[0], vad_results[1][0]
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            n = len(vadsegments)
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            batch_size_token_ms = batch_size_token*60
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
@@ -666,7 +629,8 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -679,11 +643,11 @@
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
@@ -699,12 +663,12 @@
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
            key = keys[0]
            # result = result_segments[0]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
@@ -718,23 +682,23 @@
                                                                           postprocessed_result[2]
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
            text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc-beg_punc)
                print("time cost punc: ", end_punc - beg_punc)
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)
@@ -747,11 +711,12 @@
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list
    return _forward
def inference_paraformer_online(
        maxlenratio: float,
@@ -852,7 +817,7 @@
            data = yaml.load(f, Loader=yaml.Loader)
        return data
    def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
    def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        if len(cache) > 0:
            return cache
        config = _read_yaml(asr_train_config)
@@ -868,14 +833,15 @@
        return cache
    def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
    def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        if len(cache) > 0:
            config = _read_yaml(asr_train_config)
            enc_output_size = config["encoder_conf"]["output_size"]
            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                        "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                        "tail_chunk": False}
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None}
@@ -920,7 +886,7 @@
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            sample_offset = 0
            speech_length = raw_inputs.shape[1]
            stride_size =  chunk_size[1] * 960
            stride_size = chunk_size[1] * 960
            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
            final_result = ""
            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
@@ -949,40 +915,40 @@
def inference_uniasr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    ngram_file: Optional[str] = None,
    cmvn_file: Optional[str] = None,
    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    token_num_relax: int = 1,
    decoding_ind: int = 0,
    decoding_mode: str = "model1",
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        ngram_file: Optional[str] = None,
        cmvn_file: Optional[str] = None,
        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        token_num_relax: int = 1,
        decoding_ind: int = 0,
        decoding_mode: str = "model1",
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -993,17 +959,17 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
@@ -1016,10 +982,10 @@
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1046,7 +1012,7 @@
        decoding_mode=decoding_mode,
    )
    speech2text = Speech2TextUniASR(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1059,19 +1025,17 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -1082,14 +1046,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -1097,7 +1061,7 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            logging.info(f"Utterance: {key}")
@@ -1105,12 +1069,12 @@
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1120,40 +1084,40 @@
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    return _forward
def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -1164,20 +1128,20 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1201,7 +1165,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1214,20 +1178,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -1238,14 +1200,14 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
@@ -1253,19 +1215,19 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1275,42 +1237,43 @@
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list
    return _forward
def inference_transducer(
    output_dir: str,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    lm_weight: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    key_file: Optional[str],
    allow_variable_data_keys: bool,
    quantize_asr_model: Optional[bool],
    quantize_modules: Optional[List[str]],
    quantize_dtype: Optional[str],
    streaming: Optional[bool],
    simu_streaming: Optional[bool],
    chunk_size: Optional[int],
    left_context: Optional[int],
    right_context: Optional[int],
    display_partial_hypotheses: bool,
    **kwargs,
        output_dir: str,
        batch_size: int,
        dtype: str,
        beam_size: int,
        ngpu: int,
        seed: int,
        lm_weight: float,
        nbest: int,
        num_workers: int,
        log_level: Union[int, str],
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str],
        beam_search_config: Optional[dict],
        lm_train_config: Optional[str],
        lm_file: Optional[str],
        model_tag: Optional[str],
        token_type: Optional[str],
        bpemodel: Optional[str],
        key_file: Optional[str],
        allow_variable_data_keys: bool,
        quantize_asr_model: Optional[bool],
        quantize_modules: Optional[List[str]],
        quantize_dtype: Optional[str],
        streaming: Optional[bool],
        simu_streaming: Optional[bool],
        chunk_size: Optional[int],
        left_context: Optional[int],
        right_context: Optional[int],
        display_partial_hypotheses: bool,
        **kwargs,
) -> None:
    """Transducer model inference.
    Args:
@@ -1391,7 +1354,7 @@
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1400,106 +1363,99 @@
                 **kwargs,
                 ):
        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(
                speech2text.asr_train_args, False
            ),
            collate_fn=ASRTask.build_collate_fn(
                speech2text.asr_train_args, False
            ),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        # 4 .Start for-loop
        with DatadirWriter(output_dir) as writer:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
                assert len(batch.keys()) == 1
                try:
                    if speech2text.streaming:
                        speech = batch["speech"]
                        _steps = len(speech) // speech2text._ctx
                        _end = 0
                        for i in range(_steps):
                            _end = (i + 1) * speech2text._ctx
                            speech2text.streaming_decode(
                                speech[i * speech2text._ctx : _end], is_final=False
                                speech[i * speech2text._ctx: _end], is_final=False
                            )
                        final_hyps = speech2text.streaming_decode(
                            speech[_end : len(speech)], is_final=True
                            speech[_end: len(speech)], is_final=True
                        )
                    elif speech2text.simu_streaming:
                        final_hyps = speech2text.simu_streaming_decode(**batch)
                    else:
                        final_hyps = speech2text(**batch)
                    results = speech2text.hypotheses_to_results(final_hyps)
                except TooShortUttError as e:
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                    results = [[" ", ["<space>"], [2], hyp]] * nbest
                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    ibest_writer = writer[f"{n}best_recog"]
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    if text is not None:
                        ibest_writer["text"][key] = text
    return _forward
def inference_sa_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
@@ -1508,23 +1464,23 @@
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -1548,7 +1504,7 @@
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextSAASR(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -1561,20 +1517,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -1585,7 +1539,7 @@
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
@@ -1599,20 +1553,20 @@
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    ibest_writer["text_id"][key] = text_id
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
@@ -1621,12 +1575,12 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}".format(text))
                logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list
    return _forward
@@ -1664,7 +1618,7 @@
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
@@ -1674,7 +1628,7 @@
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
@@ -1707,7 +1661,7 @@
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
@@ -1729,7 +1683,7 @@
        default=False,
        help="MultiChannel input",
    )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
@@ -1792,7 +1746,7 @@
        default={},
        help="The keyword arguments for transducer beam search.",
    )
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
@@ -1839,7 +1793,7 @@
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )
    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
@@ -1864,7 +1818,7 @@
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
@@ -1922,7 +1876,6 @@
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
if __name__ == "__main__":