游雁
2023-05-19 219c2482ab755fbd4e49dfbdee91bf1a8a4ec49a
funasr/bin/asr_infer.py
@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import sys
@@ -19,13 +23,16 @@
import numpy as np
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearch
# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
@@ -47,13 +54,12 @@
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.bin.punc_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.tasks.asr import frontend_choices
class Speech2Text:
    """Speech2Text class
@@ -264,7 +270,6 @@
        
        assert check_return_type(results)
        return results
class Speech2TextParaformer:
    """Speech2Text class
@@ -483,15 +488,20 @@
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                )
                if pre_token_length[i] == 0:
                    yseq = torch.tensor(
                        [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
                    )
                    score = torch.tensor(0.0, device=yseq.device)
                else:
                    yseq = am_scores.argmax(dim=-1)
                    score = am_scores.max(dim=-1)[0]
                    score = torch.sum(score, dim=-1)
                    # pad with mask tokens to ensure compatibility with sos/eos tokens
                    yseq = torch.tensor(
                        [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                    )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)
@@ -744,10 +754,13 @@
            feats = cache_en["feats"]
            feats_len = torch.tensor([feats.shape[1]])
            self.asr_model.frontend = None
            self.frontend.cache_reset()
            results = self.infer(feats, feats_len, cache)
            return results
        else:
            if self.frontend is not None:
                if cache_en["start_idx"] == 0:
                    self.frontend.cache_reset()
                feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
                feats = to_device(feats, device=self.device)
                feats_len = feats_len.int()
@@ -757,23 +770,6 @@
                feats_len = speech_lengths
            if feats.shape[1] != 0:
                if cache_en["is_final"]:
                    if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
                        cache_en["last_chunk"] = True
                    else:
                        # first chunk
                        feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
                        feats_len = torch.tensor([feats_chunk1.shape[1]])
                        results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
                        # last chunk
                        cache_en["last_chunk"] = True
                        feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
                        feats_len = torch.tensor([feats_chunk2.shape[1]])
                        results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
                        return [" ".join(results_chunk1 + results_chunk2)]
                results = self.infer(feats, feats_len, cache)
        return results
@@ -838,7 +834,6 @@
        # assert check_return_type(results)
        return results
class Speech2TextUniASR:
    """Speech2Text class
@@ -1072,9 +1067,7 @@
        assert check_return_type(results)
        return results
class Speech2TextMFCCA:
    """Speech2Text class
@@ -1114,6 +1107,7 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
@@ -1270,3 +1264,579 @@
        return results
class Speech2TextTransducer:
    """Speech2Text class for Transducer models.
    Args:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        beam_search_config: Beam search config path.
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        token_type: Type of token units.
        bpemodel: BPE model path.
        device: Device to use for inference.
        beam_size: Size of beam during search.
        dtype: Data type.
        lm_weight: Language model weight.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        nbest: Number of final hypotheses.
        streaming: Whether to perform chunk-by-chunk inference.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()
        from funasr.tasks.asr import ASRTransducerTask
        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
                    raise ValueError(
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )
                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
            else:
                q_config = {torch.nn.Linear}
            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch"
                    " version < 1.5.0. Switching to qint8 dtype instead."
                )
            q_dtype = getattr(torch, quantize_dtype)
            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            lm_scorer = lm.lm
        else:
            lm_scorer = None
        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}
        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
            beam_size,
            lm=lm_scorer,
            lm_weight=lm_weight,
            nbest=nbest,
            **beam_search_config,
        )
        token_list = asr_model.token_list
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)
        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()
    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
        self.frontend_cache = None
        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()
        self.num_processed_frames = torch.tensor([[0]], device=self.device)
    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.
        Args:
            speech: Chunk of speech data. (S)
            is_final: Whether speech corresponds to the final chunk of data.
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
                )
                speech = torch.cat([speech, pad],
                                   dim=0)  # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
            feats,
            feats_lengths,
            self.num_processed_frames,
            chunk_size=self.chunk_size,
            left_context=self.left_context,
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
        self.num_processed_frames += self.chunk_size
        if is_final:
            self.reset_inference_cache()
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
                                                            self.right_context)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
        """Build partial or final results from the hypotheses.
        Args:
            nbest_hyps: N-best hypothesis.
        Returns:
            results: Results containing different representation for the hypothesis.
        """
        results = []
        for hyp in nbest_hyps:
            token_int = list(filter(lambda x: x != 0, hyp.yseq))
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
            assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
            model_tag: Model tag of the pretrained models.
        Return:
            : Speech2Text instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
    """Speech2Text class
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        from funasr.tasks.sa_asr import ASRTask
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
    ) -> List[
        Tuple[
            Optional[str],
            Optional[str],
            List[str],
            List[int],
            Union[HypothesisSAASR],
        ]
    ]:
        """Inference
        Args:
            speech: Input speech data
        Returns:
            text, text_id, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
        if isinstance(asr_enc, tuple):
            asr_enc = asr_enc[0]
        if isinstance(spk_enc, tuple):
            spk_enc = spk_enc[0]
        assert len(asr_enc) == 1, len(asr_enc)
        assert len(spk_enc) == 1, len(spk_enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1: last_pos]
            else:
                token_int = hyp.yseq[1: last_pos].tolist()
            spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
            token_ori = self.converter.ids2tokens(token_int)
            text_ori = self.tokenizer.tokens2text(token_ori)
            text_ori_spklist = text_ori.split('$')
            cur_index = 0
            spk_choose = []
            for i in range(len(text_ori_spklist)):
                text_ori_split = text_ori_spklist[i]
                n = len(text_ori_split)
                spk_weights_local = spk_weigths[cur_index: cur_index + n]
                cur_index = cur_index + n + 1
                spk_weights_local = spk_weights_local.mean(dim=0)
                spk_choose_local = spk_weights_local.argmax(-1)
                spk_choose.append(spk_choose_local.item() + 1)
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            text_spklist = text.split('$')
            assert len(spk_choose) == len(text_spklist)
            spk_list = []
            for i in range(len(text_spklist)):
                text_split = text_spklist[i]
                n = len(text_split)
                spk_list.append(str(spk_choose[i]) * n)
            text_id = '$'.join(spk_list)
            assert len(text) == len(text_id)
            results.append((text, text_id, token, token_int, hyp))
        assert check_return_type(results)
        return results