嘉渊
2023-06-14 3d70934e7fed7c0d3179fec340761466205cb3e9
update repo
1个文件已修改
543 ■■■■ 已修改文件
funasr/bin/asr_infer.py 543 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -1,66 +1,48 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
import time
import codecs
import copy
import logging
import os
import re
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List
import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTask
from funasr.tasks.asr import frontend_choices
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.tasks.asr import frontend_choices
class Speech2Text:
    """Speech2Text class
@@ -73,33 +55,33 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
@@ -113,13 +95,13 @@
                from funasr.tasks.asr import frontend_choices
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -127,24 +109,24 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search import BeamSearch
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -162,13 +144,13 @@
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -180,7 +162,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -193,10 +175,10 @@
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
@@ -214,11 +196,11 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
@@ -229,48 +211,49 @@
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
class Speech2TextParaformer:
    """Speech2Text class
@@ -466,18 +449,21 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
                                                                                pre_token_length)  # test no bias cif2
        results = []
        b, n, d = decoder_out.size()
@@ -527,12 +513,11 @@
                    text = None
                timestamp = []
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
                                                            us_peaks[i][:enc_len[i]*3],
                                                            copy.copy(token),
                                                            vad_offset=begin_time)
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
                                                               us_peaks[i][:enc_len[i] * 3],
                                                               copy.copy(token),
                                                               vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
@@ -590,6 +575,7 @@
        else:
            hotword_list = None
        return hotword_list
class Speech2TextParaformerOnline:
    """Speech2Text class
@@ -789,7 +775,7 @@
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
        predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
        pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1]
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
@@ -839,11 +825,12 @@
                        postprocessed_result += item + " "
                    else:
                        postprocessed_result += item
                results.append(postprocessed_result)
        # assert check_return_type(results)
        return results
class Speech2TextUniASR:
    """Speech2Text class
@@ -1077,7 +1064,7 @@
        assert check_return_type(results)
        return results
class Speech2TextMFCCA:
    """Speech2Text class
@@ -1090,45 +1077,45 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -1136,7 +1123,7 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
@@ -1148,11 +1135,11 @@
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -1176,7 +1163,7 @@
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1188,7 +1175,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -1200,10 +1187,10 @@
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
@@ -1231,45 +1218,45 @@
        # lenghts: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1298,45 +1285,45 @@
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            beam_search_config: Dict[str, Any] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            beam_size: int = 5,
            dtype: str = "float32",
            lm_weight: float = 1.0,
            quantize_asr_model: bool = False,
            quantize_modules: List[str] = None,
            quantize_dtype: str = "qint8",
            nbest: int = 1,
            streaming: bool = False,
            simu_streaming: bool = False,
            chunk_size: int = 16,
            left_context: int = 32,
            right_context: int = 0,
            display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()
        from funasr.tasks.asr import ASRTransducerTask
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@@ -1344,24 +1331,24 @@
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )
                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
            else:
                q_config = {torch.nn.Linear}
            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch"
                    " version < 1.5.0. Switching to qint8 dtype instead."
                )
            q_dtype = getattr(torch, quantize_dtype)
            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
@@ -1369,11 +1356,11 @@
            lm_scorer = lm.lm
        else:
            lm_scorer = None
        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}
        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
@@ -1383,14 +1370,14 @@
            nbest=nbest,
            **beam_search_config,
        )
        token_list = asr_model.token_list
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1402,60 +1389,60 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)
        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
                    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()
    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
        self.frontend_cache = None
        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()
        self.num_processed_frames = torch.tensor([[0]], device=self.device)
    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
            self,
            speech: Union[torch.Tensor, np.ndarray],
            is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.
        Args:
@@ -1473,13 +1460,13 @@
                )
                speech = torch.cat([speech, pad],
                                   dim=0)  # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
@@ -1491,14 +1478,14 @@
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
        self.num_processed_frames += self.chunk_size
        if is_final:
            self.reset_inference_cache()
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
@@ -1508,29 +1495,29 @@
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
                                                            self.right_context)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
@@ -1540,7 +1527,7 @@
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1548,19 +1535,19 @@
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
        """Build partial or final results from the hypotheses.
        Args:
@@ -1569,26 +1556,26 @@
            results: Results containing different representation for the hypothesis.
        """
        results = []
        for hyp in nbest_hyps:
            token_int = list(filter(lambda x: x != 0, hyp.yseq))
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
            assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
@@ -1599,7 +1586,7 @@
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
@@ -1608,7 +1595,7 @@
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2TextTransducer(**kwargs)
@@ -1623,33 +1610,33 @@
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        from funasr.tasks.sa_asr import ASRTask
        scorers = {}
@@ -1663,13 +1650,13 @@
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
@@ -1677,24 +1664,24 @@
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
@@ -1712,13 +1699,13 @@
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
@@ -1730,7 +1717,7 @@
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
@@ -1743,11 +1730,11 @@
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
            profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
    ) -> List[
        Tuple[
            Optional[str],
@@ -1766,14 +1753,14 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
@@ -1784,10 +1771,10 @@
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
        if isinstance(asr_enc, tuple):
@@ -1796,30 +1783,30 @@
            spk_enc = spk_enc[0]
        assert len(asr_enc) == 1, len(asr_enc)
        assert len(spk_enc) == 1, len(spk_enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1: last_pos]
            else:
                token_int = hyp.yseq[1: last_pos].tolist()
            spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
            token_ori = self.converter.ids2tokens(token_int)
            text_ori = self.tokenizer.tokens2text(token_ori)
            text_ori_spklist = text_ori.split('$')
            cur_index = 0
            spk_choose = []
@@ -1831,32 +1818,32 @@
                spk_weights_local = spk_weights_local.mean(dim=0)
                spk_choose_local = spk_weights_local.argmax(-1)
                spk_choose.append(spk_choose_local.item() + 1)
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            text_spklist = text.split('$')
            assert len(spk_choose) == len(text_spklist)
            spk_list = []
            for i in range(len(text_spklist)):
                text_split = text_spklist[i]
                n = len(text_split)
                spk_list.append(str(spk_choose[i]) * n)
            text_id = '$'.join(spk_list)
            assert len(text) == len(text_id)
            results.append((text, text_id, token, token_int, hyp))
        assert check_return_type(results)
        return results