游雁
2023-06-29 bc723ea200144bd6fa8a5dff4b9a780feda144fc
funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
from  funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -192,7 +189,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -288,7 +283,6 @@
            decoding_ind: int = 0,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -377,6 +371,7 @@
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.cmvn_file = cmvn_file
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
@@ -412,7 +407,6 @@
                text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -515,10 +509,47 @@
                                                               vad_offset=begin_time)
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
    def generate_hotwords_list(self, hotword_list_or_file):
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict
        def seg_tokenize(txt, seg_dict):
            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
                        out_txt += "<unk>" + " "
            return out_txt.strip().split()
        seg_dict = None
        if self.cmvn_file is not None:
            model_dir = os.path.dirname(self.cmvn_file)
            seg_dict_file = os.path.join(model_dir, 'seg_dict')
            if os.path.exists(seg_dict_file):
                seg_dict = load_seg_dict(seg_dict_file)
            else:
                seg_dict = None
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
@@ -530,8 +561,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +585,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +601,10 @@
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                hw_list = hw.strip().split()
                if seg_dict is not None:
                    hw_list = seg_tokenize(hw_list, seg_dict)
                hotword_list.append(self.converter.tokens2ids(hw_list))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
@@ -608,7 +648,6 @@
            hotword_list_or_file: str = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -728,7 +767,6 @@
                text, token, token_int, hyp
        """
        assert check_argument_types()
        results = []
        cache_en = cache["encoder"]
        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -823,7 +861,6 @@
                results.append(postprocessed_result)
        # assert check_return_type(results)
        return results
@@ -864,7 +901,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -988,7 +1024,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1056,7 +1091,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1095,7 +1129,6 @@
            streaming: bool = False,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -1200,7 +1233,6 @@
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1250,7 +1282,6 @@
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
@@ -1307,7 +1338,6 @@
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
@@ -1486,7 +1516,6 @@
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1518,7 +1547,6 @@
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -1560,35 +1588,8 @@
                text = None
            results.append((text, token, token_int, hyp))
            assert check_return_type(results)
        return results
    @staticmethod
    def from_pretrained(
            model_tag: Optional[str] = None,
            **kwargs: Optional[Any],
    ) -> Speech2Text:
        """Build Speech2Text instance from the pretrained model.
        Args:
            model_tag: Model tag of the pretrained models.
        Return:
            : Speech2Text instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
@@ -1627,7 +1628,6 @@
            frontend_conf: dict = None,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
@@ -1636,8 +1636,10 @@
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            from funasr.tasks.sa_asr import frontend_choices
            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1743,7 +1745,6 @@
            text, text_id, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
@@ -1836,5 +1837,4 @@
            results.append((text, text_id, token, token_int, hyp))
        assert check_return_type(results)
        return results