游雁
2023-10-10 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3
funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from  funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -192,7 +189,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -285,10 +280,10 @@
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            clas_scale: float = 1.0,
            decoding_ind: int = 0,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -377,10 +372,12 @@
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.cmvn_file = cmvn_file
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
        self.clas_scale = clas_scale
        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -402,7 +399,7 @@
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
    ):
        """Inference
@@ -412,7 +409,6 @@
                text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -433,7 +429,9 @@
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if decoding_ind is None:
            decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
@@ -445,16 +443,20 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
                                                                                   NeatContextualParaformer):
        if not isinstance(self.asr_model, ContextualParaformer) and \
            not isinstance(self.asr_model, NeatContextualParaformer):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                     pre_token_length, hw_list=self.hotword_list)
            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
                                                                     enc_len,
                                                                     pre_acoustic_embeds,
                                                                     pre_token_length,
                                                                     hw_list=self.hotword_list,
                                                                     clas_scale=self.clas_scale)
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
@@ -515,10 +517,47 @@
                                                               vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
    def generate_hotwords_list(self, hotword_list_or_file):
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict
        def seg_tokenize(txt, seg_dict):
            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
                        out_txt += "<unk>" + " "
            return out_txt.strip().split()
        seg_dict = None
        if self.cmvn_file is not None:
            model_dir = os.path.dirname(self.cmvn_file)
            seg_dict_file = os.path.join(model_dir, 'seg_dict')
            if os.path.exists(seg_dict_file):
                seg_dict = load_seg_dict(seg_dict_file)
            else:
                seg_dict = None
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
@@ -530,8 +569,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +593,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +609,10 @@
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                hw_list = hw.strip().split()
                if seg_dict is not None:
                    hw_list = seg_tokenize(hw_list, seg_dict)
                hotword_list.append(self.converter.tokens2ids(hw_list))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
@@ -608,7 +656,6 @@
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -728,7 +775,6 @@
                text, token, token_int, hyp
        """
        assert check_argument_types()
        results = []
        cache_en = cache["encoder"]
        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -823,7 +869,6 @@
                results.append(postprocessed_result)
        # assert check_return_type(results)
        return results
@@ -864,7 +909,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -988,7 +1032,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1056,7 +1099,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1095,7 +1137,6 @@
            streaming: bool = False,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -1200,7 +1241,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1250,7 +1290,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1298,7 +1337,8 @@
            quantize_dtype: str = "qint8",
            nbest: int = 1,
            streaming: bool = False,
            simu_streaming: bool = False,
            fake_streaming: bool = False,
            full_utt: bool = False,
            chunk_size: int = 16,
            left_context: int = 32,
            right_context: int = 0,
@@ -1307,7 +1347,6 @@
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
@@ -1393,7 +1432,8 @@
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.fake_streaming = fake_streaming
        self.full_utt = full_utt
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)
@@ -1402,8 +1442,8 @@
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
        if not fake_streaming or chunk_size == 0:
            self.fake_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
@@ -1413,6 +1453,7 @@
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
            self._right_ctx = right_context
            self.last_chunk_length = (
                    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
@@ -1479,14 +1520,13 @@
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1511,7 +1551,7 @@
        return nbest_hyps
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
    def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
@@ -1519,6 +1559,36 @@
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
    @torch.no_grad()
    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1560,35 +1630,8 @@
                text = None
            results.append((text, token, token_int, hyp))
            assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
            model_tag: Model tag of the pretrained models.
        Return:
            : Speech2Text instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
@@ -1627,7 +1670,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -1745,7 +1787,6 @@
            text, text_id, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1838,5 +1879,4 @@
            results.append((text, text_id, token, token_int, hyp))
        assert check_return_type(results)
        return results