游雁
2023-03-31 4ba1011b42e041ee1d71448eefd7ef2e7bd61bb6
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -1,8 +1,14 @@
#!/usr/bin/env python3
import json
import argparse
import logging
import sys
import time
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -12,6 +18,7 @@
from typing import Any
from typing import List
import math
import copy
import numpy as np
import torch
from typeguard import check_argument_types
@@ -36,27 +43,22 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
# ANSI escape sequences: bright-magenta foreground, and reset of all attributes.
header_colors = '\033[95m'
end_colors = '\033[0m'

# Module-wide defaults: recognition language tag and the sample rates (in Hz)
# recorded for the incoming audio and for the model.
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = dict(audio_fs=16000, model_fs=16000)
class Speech2Text:
    """Speech2Text class
    Examples:
            >>> import soundfile
            >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
            >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
            >>> audio, rate = soundfile.read("speech.wav")
            >>> speech2text(audio)
            [(text, token, token_int, hypothesis object), ...]
@@ -83,6 +85,7 @@
            penalty: float = 0.0,
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        assert check_argument_types()
@@ -100,10 +103,13 @@
        # logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        if asr_model.ctc != None:
            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = asr_model.token_list
        scorers.update(
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
@@ -145,7 +151,7 @@
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
@@ -170,8 +176,13 @@
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if ctc_weight == 0.0 and not is_use_lm:
        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
            beam_search = None
        self.beam_search = beam_search
        logging.info(f"Beam_search: {self.beam_search}")
@@ -185,12 +196,11 @@
        self.encoder_downsampling_factor = 1
        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, begin_time: int = 0, end_time: int = None,
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
    ):
        """Inference
@@ -216,7 +226,7 @@
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1]//80)-1)
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
@@ -230,10 +240,24 @@
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
        results = []
        b, n, d = decoder_out.size()
@@ -244,7 +268,7 @@
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
@@ -255,349 +279,134 @@
                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                if len(token_int) == 0:
                    continue
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None
                time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
                results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
                                                            us_peaks[i],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
class Speech2VadSegment:
    """Speech2VadSegment class
    Examples:
        >>> import soundfile
        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2segment(audio)
        [[10, 230], [245, 450], ...]
    """
    def __init__(
            self,
            vad_infer_config: Union[Path, str] = None,
            vad_model_file: Union[Path, str] = None,
            vad_cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
            batch_size: int = 1,
            dtype: str = "float32",
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build vad model
        vad_model, vad_infer_args = VADTask.build_model_from_file(
            vad_infer_config, vad_model_file, device
        )
        frontend = None
        if vad_infer_args.frontend is not None:
            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
        # logging.info("vad_model: {}".format(vad_model))
        # logging.info("vad_infer_args: {}".format(vad_infer_args))
        vad_model.to(dtype=getattr(torch, dtype)).eval()
        self.vad_model = vad_model
        self.vad_infer_args = vad_infer_args
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[List[int]]:
        """Compute the frontend features that are fed to the VAD model.

        Args:
            speech: Input speech data; a ``np.ndarray`` is converted with
                ``torch.tensor`` first.
            speech_lengths: Forwarded to ``self.frontend.forward_fbank``
                together with ``speech``.

        Returns:
            NOTE(review): no ``return`` statement is visible in this span. The
            earlier description ("text, token, token_int, hyp") names nothing
            that is computed here, and callers in this file unpack the result
            as ``fbanks, vadsegments`` - confirm against the full method and
            against the ``List[List[int]]`` annotation.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            # Lift the frontend's length limit before extracting features
            # (presumably so long recordings are not filtered out - confirm).
            self.frontend.filter_length_max = math.inf
            # Two-stage frontend: fbank first, then forward_lfr_cmvn on its output.
            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
            # Both feature tensors are moved to the configured device.
            fbanks = to_device(fbanks, device=self.device)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
    def generate_hotwords_list(self, hotword_list_or_file):
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords from local txt...")
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(hotword_list_or_file, hotword_str_list))
        # for url, download and generate txt
        elif hotword_list_or_file.startswith('http'):
            logging.info("Attempting to parse hotwords from url...")
            work_dir = tempfile.TemporaryDirectory().name
            if not os.path.exists(work_dir):
                os.makedirs(work_dir)
            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
            local_file = requests.get(hotword_list_or_file)
            open(text_file_path, "wb").write(local_file.content)
            hotword_list_or_file = text_file_path
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(hotword_list_or_file, hotword_str_list))
        # for text str input
        elif not hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords as str...")
            hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
        else:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech}
            hotword_list = None
        return hotword_list
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        segments = self.vad_model(**batch)
        return fbanks, segments
# def inference(
#         maxlenratio: float,
#         minlenratio: float,
#         batch_size: int,
#         beam_size: int,
#         ngpu: int,
#         ctc_weight: float,
#         lm_weight: float,
#         penalty: float,
#         log_level: Union[int, str],
#         data_path_and_name_and_type,
#         asr_train_config: Optional[str],
#         asr_model_file: Optional[str],
#         cmvn_file: Optional[str] = None,
#         raw_inputs: Union[np.ndarray, torch.Tensor] = None,
#         lm_train_config: Optional[str] = None,
#         lm_file: Optional[str] = None,
#         token_type: Optional[str] = None,
#         key_file: Optional[str] = None,
#         word_lm_train_config: Optional[str] = None,
#         bpemodel: Optional[str] = None,
#         allow_variable_data_keys: bool = False,
#         streaming: bool = False,
#         output_dir: Optional[str] = None,
#         dtype: str = "float32",
#         seed: int = 0,
#         ngram_weight: float = 0.9,
#         nbest: int = 1,
#         num_workers: int = 1,
#         vad_infer_config: Optional[str] = None,
#         vad_model_file: Optional[str] = None,
#         vad_cmvn_file: Optional[str] = None,
#         time_stamp_writer: bool = False,
#         punc_infer_config: Optional[str] = None,
#         punc_model_file: Optional[str] = None,
#         **kwargs,
# ):
#     assert check_argument_types()
#
#     if word_lm_train_config is not None:
#         raise NotImplementedError("Word LM is not implemented")
#     if ngpu > 1:
#         raise NotImplementedError("only single GPU decoding is supported")
#
#     logging.basicConfig(
#         level=log_level,
#         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
#     )
#
#     if ngpu >= 1 and torch.cuda.is_available():
#         device = "cuda"
#     else:
#         device = "cpu"
#
#     # 1. Set random-seed
#     set_all_random_seed(seed)
#
#     # 2. Build speech2vadsegment
#     speech2vadsegment_kwargs = dict(
#         vad_infer_config=vad_infer_config,
#         vad_model_file=vad_model_file,
#         vad_cmvn_file=vad_cmvn_file,
#         device=device,
#         dtype=dtype,
#     )
#     # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
#     speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
#
#     # 3. Build speech2text
#     speech2text_kwargs = dict(
#         asr_train_config=asr_train_config,
#         asr_model_file=asr_model_file,
#         cmvn_file=cmvn_file,
#         lm_train_config=lm_train_config,
#         lm_file=lm_file,
#         token_type=token_type,
#         bpemodel=bpemodel,
#         device=device,
#         maxlenratio=maxlenratio,
#         minlenratio=minlenratio,
#         dtype=dtype,
#         beam_size=beam_size,
#         ctc_weight=ctc_weight,
#         lm_weight=lm_weight,
#         ngram_weight=ngram_weight,
#         penalty=penalty,
#         nbest=nbest,
#         frontend_conf=frontend_conf,
#     )
#     speech2text = Speech2Text(**speech2text_kwargs)
#
#     text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
#
#     # 3. Build data-iterator
#     loader = ASRTask.build_streaming_iterator(
#         data_path_and_name_and_type,
#         dtype=dtype,
#         batch_size=1,
#         key_file=key_file,
#         num_workers=num_workers,
#         preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
#         collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
#         allow_variable_data_keys=allow_variable_data_keys,
#         inference=True,
#     )
#
#     forward_time_total = 0.0
#     length_total = 0.0
#     finish_count = 0
#     file_count = 1
#     # 7 .Start for-loop
#     asr_result_list = []
#     if output_dir is not None:
#         writer = DatadirWriter(output_dir)
#     else:
#         writer = None
#
#     for keys, batch in loader:
#         assert isinstance(batch, dict), type(batch)
#         assert all(isinstance(s, str) for s in keys), keys
#         _bs = len(next(iter(batch.values())))
#         assert len(keys) == _bs, f"{len(keys)} != {_bs}"
#         # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
#
#         logging.info("decoding, utt_id: {}".format(keys))
#         # N-best list of (text, token, token_int, hyp_object)
#         time_beg = time.time()
#         vad_results = speech2vadsegment(**batch)
#         time_end = time.time()
#         fbanks, vadsegments = vad_results[0], vad_results[1]
#         for i, segments in enumerate(vadsegments):
#             result_segments = [["", [], [], ]]
#             for j, segment_idx in enumerate(segments):
#                 bed_idx, end_idx = int(segment_idx[0]/10), int(segment_idx[1]/10)
#                 segment = fbanks[:, bed_idx:end_idx, :].to(device)
#                 speech_lengths = torch.Tensor([end_idx-bed_idx]).int().to(device)
#                 batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0], "end_time": vadsegments[i][j][1]}
#                 results = speech2text(**batch)
#                 if len(results) < 1:
#                     hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
#                     results = [[" ", ["<space>"], [2], 10, 6]] * nbest
#                 time_end = time.time()
#                 forward_time = time_end - time_beg
#                 lfr_factor = results[0][-1]
#                 length = results[0][-2]
#                 forward_time_total += forward_time
#                 length_total += length
#                 logging.info(
#                     "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
#                         format(length, forward_time, 100 * forward_time / (length*lfr_factor)))
#                 result_cur = [results[0][:-2]]
#                 if j == 0:
#                     result_segments = result_cur
#                 else:
#                     result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
#
#             key = keys[0]
#             result = result_segments[0]
#             text, token, token_int, time_stamp = result
#
#             # Create a directory: outdir/{n}best_recog
#             if writer is not None:
#                 ibest_writer = writer[f"1best_recog"]
#
#                 # Write the result to each file
#                 ibest_writer["token"][key] = " ".join(token)
#                 ibest_writer["token_int"][key] = " ".join(map(str, token_int))
#
#             if text is not None:
#                 postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
#                 if len(postprocessed_result) == 3:
#                     text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1], postprocessed_result[2]
#                     text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
#                     text_postprocessed_punc_time_stamp = "predictions: {}  time_stamp: {}".format(text_postprocessed_punc, time_stamp_postprocessed)
#                 else:
#                     text_postprocessed = postprocessed_result
#                     time_stamp_postprocessed = None
#                     word_lists = None
#                     text_postprocessed_punc_time_stamp = None
#                     punc_id_list = None
#
#                 item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed, 'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list}
#                 asr_result_list.append(item)
#                 finish_count += 1
#                 # asr_utils.print_progress(finish_count / file_count)
#                 if writer is not None:
#                     ibest_writer["text"][key] = text_postprocessed
#                     if time_stamp_writer and time_stamp_postprocessed is not None:
#                         ibest_writer["time_stamp"][key] = " ".join(["-".join(map(str, ts)) for ts in time_stamp_postprocessed])
#
#             logging.info("decoding, utt: {}, predictions: {}, time_stamp: {}".format(key, text_postprocessed_punc, time_stamp_postprocessed))
#
#     logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
#                  format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
#     return asr_result_list
def inference(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = False,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        vad_infer_config: Optional[str] = None,
        vad_model_file: Optional[str] = None,
        vad_cmvn_file: Optional[str] = None,
        time_stamp_writer: bool = False,
        punc_infer_config: Optional[str] = None,
        punc_model_file: Optional[str] = None,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
@@ -636,61 +445,69 @@
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = False,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    **kwargs,
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        vad_infer_config: Optional[str] = None,
        vad_model_file: Optional[str] = None,
        vad_cmvn_file: Optional[str] = None,
        time_stamp_writer: bool = True,
        punc_infer_config: Optional[str] = None,
        punc_model_file: Optional[str] = None,
        outputs_dict: Optional[bool] = True,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
@@ -701,7 +518,7 @@
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
@@ -721,15 +538,36 @@
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2Text(**speech2text_kwargs)
    text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
@@ -738,6 +576,7 @@
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
@@ -746,34 +585,33 @@
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        forward_time_total = 0.0
        length_total = 0.0
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7 .Start for-loop
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            vad_results = speech2vadsegment(**batch)
            time_end = time.time()
            fbanks, vadsegments = vad_results[0], vad_results[1]
            for i, segments in enumerate(vadsegments):
                result_segments = [["", [], [], ]]
                result_segments = [["", [], [], []]]
                for j, segment_idx in enumerate(segments):
                    bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
                    segment = fbanks[:, bed_idx:end_idx, :].to(device)
@@ -782,159 +620,66 @@
                             "end_time": vadsegments[i][j][1]}
                    results = speech2text(**batch)
                    if len(results) < 1:
                        hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                        results = [[" ", ["<space>"], [2], 10, 6]] * nbest
                    time_end = time.time()
                    forward_time = time_end - time_beg
                    lfr_factor = results[0][-1]
                    length = results[0][-2]
                    forward_time_total += forward_time
                    length_total += length
                    logging.info(
                        "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
                        format(length, forward_time, 100 * forward_time / (length * lfr_factor)))
                        continue
                    result_cur = [results[0][:-2]]
                    if j == 0:
                        result_segments = result_cur
                    else:
                        result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
                        result_segments = [
                            [result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
                key = keys[0]
                result = result_segments[0]
                text, token, token_int, time_stamp = result
                # Create a directory: outdir/{n}best_recog
                text, token, token_int = result[0], result[1], result[2]
                time_stamp = None if len(result) < 4 else result[3]
                if use_timestamp and time_stamp is not None:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                else:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
                text_postprocessed = ""
                time_stamp_postprocessed = ""
                text_postprocessed_punc = postprocessed_result
                if len(postprocessed_result) == 3:
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                               postprocessed_result[1], \
                                                                               postprocessed_result[2]
                else:
                    text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                text_postprocessed_punc = text_postprocessed
                punc_id_list = []
                if len(word_lists) > 0 and text2punc is not None:
                    text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                item = {'key': key, 'value': text_postprocessed_punc}
                if text_postprocessed != "":
                    item['text_postprocessed'] = text_postprocessed
                if time_stamp_postprocessed != "":
                    item['time_stamp'] = time_stamp_postprocessed
                item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
                asr_result_list.append(item)
                finish_count += 1
                # asr_utils.print_progress(finish_count / file_count)
                if writer is not None:
                    ibest_writer = writer[f"1best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                if text is not None:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                    if len(postprocessed_result) == 3:
                        text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                   postprocessed_result[1], \
                                                                                   postprocessed_result[2]
                        text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                        text_postprocessed_punc_time_stamp = "predictions: {}  time_stamp: {}".format(
                            text_postprocessed_punc, time_stamp_postprocessed)
                    else:
                        text_postprocessed = postprocessed_result
                        time_stamp_postprocessed = None
                        word_lists = None
                        text_postprocessed_punc_time_stamp = None
                        punc_id_list = None
                    item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list}
                    asr_result_list.append(item)
                    finish_count += 1
                    # asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text_postprocessed
                        if time_stamp_writer and time_stamp_postprocessed is not None:
                            ibest_writer["time_stamp"][key] = " ".join(
                                ["-".join(map(str, ts)) for ts in time_stamp_postprocessed])
                logging.info("decoding, utt: {}, predictions: {}, time_stamp: {}".format(key, text_postprocessed_punc,
                                                                                         time_stamp_postprocessed))
        logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                    if time_stamp_postprocessed is not None:
                        ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
                logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list
    return _forward
def Text2Punc(
    train_config: Optional[str],
    model_file: Optional[str],
    device: str = "cpu",
    dtype: str = "float32",
):
    """Build a callable that inserts punctuation into a list of words.

    The punctuation model is loaded once with
    ``PunctuationTask.build_model_from_file``, wrapped in a
    ``ForwardAdaptor``, cast to ``dtype``, moved to ``device`` and put in
    eval mode. The returned closure reuses that model on every call.

    Args:
        train_config: Path to the training config handed to
            ``PunctuationTask.build_model_from_file``.
        model_file: Path to the model checkpoint handed to the same call.
        device: Device string the model and the input batches are moved to.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``).

    Returns:
        A function ``_forward(words, split_size=20)`` returning the tuple
        ``(text_with_punctuation, punctuation_id_string)``.
    """
    # 2. Build Model
    model, train_args = PunctuationTask.build_model_from_file(
        train_config, model_file, device)
    # Wrap the model with ForwardAdaptor, naming "inference" as the method
    # to call (presumably forward() is routed to model.inference — confirm).
    wrapped_model = ForwardAdaptor(model, "inference")
    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()

    punc_list = train_args.punc_list
    # NOTE(review): as written, each branch stores the very string it just
    # matched, so this loop leaves punc_list unchanged. The replacements
    # were presumably meant to be different (e.g. full-width) marks —
    # confirm against the model's punc_list before changing them, because
    # the sentence-end test in _forward compares against the same strings.
    for idx, mark in enumerate(punc_list):
        if mark == ",":
            punc_list[idx] = ","
        elif mark == "?":
            punc_list[idx] = "?"

    preprocessor = CommonPreprocessor(
        train=False,
        token_type="word",
        token_list=train_args.token_list,
        bpemodel=train_args.bpemodel,
        text_cleaner=train_args.cleaner,
        g2p_type=train_args.g2p,
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
    )
    print("start decoding!!!")

    def _forward(words, split_size=20):
        """Predict punctuation for ``words`` and return the punctuated text.

        Args:
            words: Sequence of word strings; it is split into chunks by
                ``split_to_mini_sentence(words, split_size)``.
            split_size: Chunk size forwarded to ``split_to_mini_sentence``.

        Returns:
            ``(text, punc_ids)`` where ``text`` is the words joined with the
            predicted punctuation marks and ``punc_ids`` is the
            concatenation of ``str(id)`` for every emitted word's predicted
            punctuation id.
        """
        cache_sent = []
        mini_sentences = split_to_mini_sentence(words, split_size)
        last_index = len(mini_sentences) - 1
        text_parts = []
        punc_id_parts = []

        for sentence_index, chunk in enumerate(mini_sentences):
            # Words left over from the previous chunk (after its last
            # sentence end) are decoded again together with this chunk.
            mini_sentence = cache_sent + chunk
            data = {"text": " ".join(mini_sentence)}
            batch = preprocessor(data=data, uid="12938712838719")
            batch["text_lengths"] = torch.from_numpy(np.array([len(batch["text"])], dtype='int32'))
            batch["text"] = torch.from_numpy(batch["text"])
            # Extend one dimension to fake a batch dim.
            batch["text"] = torch.unsqueeze(batch["text"], 0)
            batch = to_device(batch, device)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                y, _ = wrapped_model(**batch)
                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            # Always flatten to a plain list of ints. Leaving the
            # single-token case as a (1, 1) tensor made the id string come
            # out as "[k]" instead of "k".
            punc_ids = indices.reshape(-1).tolist()
            assert len(punc_ids) == len(mini_sentence)

            # Search for the last Period/QuestionMark as cache
            if sentence_index < last_index:
                sentence_end = -1
                # Scans from the second-to-last position down to index 2.
                for pos in range(len(punc_ids) - 2, 1, -1):
                    if punc_list[punc_ids[pos]] in ("。", "?"):
                        sentence_end = pos
                        break
                # With no sentence end found (sentence_end == -1) the whole
                # chunk is carried over and nothing is emitted this round.
                cache_sent = mini_sentence[sentence_end + 1:]
                mini_sentence = mini_sentence[:sentence_end + 1]
                punc_ids = punc_ids[:sentence_end + 1]

            punc_id_parts.extend(str(punc_id) for punc_id in punc_ids)

            words_with_punc = []
            prev_single_byte = False
            for word, punc_id in zip(mini_sentence, punc_ids):
                # A word whose first character encodes to one byte gets a
                # separating space when the previous word does too.
                # Slicing (word[:1]) keeps an empty word from raising.
                single_byte = len(word[:1].encode()) == 1
                if single_byte and prev_single_byte:
                    word = " " + word
                words_with_punc.append(word)
                # "_" entries are not emitted (presumably the
                # "no punctuation" label — confirm against punc_list).
                if punc_list[punc_id] != "_":
                    words_with_punc.append(punc_list[punc_id])
                prev_single_byte = single_byte
            text_parts.append("".join(words_with_punc))

        return "".join(text_parts), "".join(punc_id_parts)

    return _forward
def get_parser():
    parser = config_argparse.ArgumentParser(