From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/bin/asr_infer.py |  226 +++++++++++++++++++++++++++++++++++++++++--------------
 1 files changed, 167 insertions(+), 59 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 0ce8dd8..7015eb8 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
 import requests
 import torch
 from packaging.version import parse as V
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from  funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
 from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -192,7 +189,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -289,7 +284,6 @@
             decoding_ind: int = 0,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -405,7 +399,7 @@
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            begin_time: int = 0, end_time: int = None,
+            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
     ):
         """Inference
 
@@ -415,7 +409,6 @@
                 text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -436,7 +429,9 @@
         batch = to_device(batch, device=self.device)
 
         # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
+        if decoding_ind is None:
+            decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
+        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)
@@ -522,7 +517,6 @@
                                                                vad_offset=begin_time)
                 results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
 
-        # assert check_return_type(results)
         return results
 
     def generate_hotwords_list(self, hotword_list_or_file):
@@ -662,7 +656,6 @@
             hotword_list_or_file: str = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -782,7 +775,6 @@
                 text, token, token_int, hyp
 
         """
-        assert check_argument_types()
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -877,7 +869,6 @@
 
                 results.append(postprocessed_result)
 
-        # assert check_return_type(results)
         return results
 
 
@@ -918,7 +909,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -1042,7 +1032,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -1110,7 +1099,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -1149,7 +1137,6 @@
             streaming: bool = False,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -1254,7 +1241,6 @@
             text, token, token_int, hyp
 
         """
-        assert check_argument_types()
         # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -1304,7 +1290,6 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-        assert check_return_type(results)
         return results
 
 
@@ -1352,7 +1337,8 @@
             quantize_dtype: str = "qint8",
             nbest: int = 1,
             streaming: bool = False,
-            simu_streaming: bool = False,
+            fake_streaming: bool = False,
+            full_utt: bool = False,
             chunk_size: int = 16,
             left_context: int = 32,
             right_context: int = 0,
@@ -1361,7 +1347,6 @@
         """Construct a Speech2Text object."""
         super().__init__()
 
-        assert check_argument_types()
         asr_model, asr_train_args = build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
         )
@@ -1447,7 +1432,8 @@
 
         self.beam_search = beam_search
         self.streaming = streaming
-        self.simu_streaming = simu_streaming
+        self.fake_streaming = fake_streaming
+        self.full_utt = full_utt
         self.chunk_size = max(chunk_size, 0)
         self.left_context = left_context
         self.right_context = max(right_context, 0)
@@ -1456,8 +1442,8 @@
             self.streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False
 
-        if not simu_streaming or chunk_size == 0:
-            self.simu_streaming = False
+        if not fake_streaming or chunk_size == 0:
+            self.fake_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False
 
         self.frontend = frontend
@@ -1467,6 +1453,7 @@
             self._ctx = self.asr_model.encoder.get_encoder_input_size(
                 self.window_size
             )
+            self._right_ctx = right_context
 
             self.last_chunk_length = (
                     self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
@@ -1533,14 +1520,13 @@
         return nbest_hyps
 
     @torch.no_grad()
-    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
         """Speech2Text call.
         Args:
             speech: Speech data. (S)
         Returns:
             nbest_hypothesis: N-best hypothesis.
         """
-        assert check_argument_types()
 
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -1565,7 +1551,7 @@
         return nbest_hyps
 
     @torch.no_grad()
-    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+    def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
         """Speech2Text call.
         Args:
             speech: Speech data. (S)
@@ -1585,11 +1571,40 @@
             feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
             feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
 
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
+        nbest_hyps = self.beam_search(enc_out[0])
+
+        return nbest_hyps
+
+    @torch.no_grad()
+    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+        """Speech2Text call.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
         feats = to_device(feats, device=self.device)
         feats_lengths = to_device(feats_lengths, device=self.device)
 
         enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
         nbest_hyps = self.beam_search(enc_out[0])
 
         return nbest_hyps
@@ -1614,35 +1629,8 @@
                 text = None
             results.append((text, token, token_int, hyp))
 
-            assert check_return_type(results)
 
         return results
-
-    @staticmethod
-    def from_pretrained(
-            model_tag: Optional[str] = None,
-            **kwargs: Optional[Any],
-    ) -> Speech2Text:
-        """Build Speech2Text instance from the pretrained model.
-        Args:
-            model_tag: Model tag of the pretrained models.
-        Return:
-            : Speech2Text instance.
-        """
-        if model_tag is not None:
-            try:
-                from espnet_model_zoo.downloader import ModelDownloader
-
-            except ImportError:
-                logging.error(
-                    "`espnet_model_zoo` is not installed. "
-                    "Please install via `pip install -U espnet_model_zoo`."
-                )
-                raise
-            d = ModelDownloader()
-            kwargs.update(**d.download_and_unpack(model_tag))
-
-        return Speech2TextTransducer(**kwargs)
 
 
 class Speech2TextSAASR:
@@ -1681,7 +1669,6 @@
             frontend_conf: dict = None,
             **kwargs,
     ):
-        assert check_argument_types()
 
         # 1. Build ASR model
         scorers = {}
@@ -1799,7 +1786,6 @@
             text, text_id, token, token_int, hyp
 
         """
-        assert check_argument_types()
 
         # Input as audio signal
         if isinstance(speech, np.ndarray):
@@ -1892,5 +1878,127 @@
 
             results.append((text, text_id, token, token_int, hyp))
 
-        assert check_return_type(results)
+        return results
+
+
+class Speech2TextWhisper:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            frontend_conf: dict = None,
+            language: str = None,
+            task: str = "transcribe",
+            **kwargs,
+    ):
+
+        from funasr.tasks.whisper import ASRTask
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        frontend = None
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        token_list = []
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.tokenizer = tokenizer
+        self.device = device
+        self.dtype = dtype
+        self.frontend = frontend
+        self.language = language
+        self.task = task
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+    ) -> List[
+        Tuple[
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, token, token_int, hyp
+
+        """
+
+        from funasr.utils.whisper_utils.transcribe import transcribe
+        from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
+        from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
+
+        speech = speech[0]
+        speech = pad_or_trim(speech)
+        mel = log_mel_spectrogram(speech).to(self.device)
+
+        if self.asr_model.is_multilingual:
+            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
+            asr_res = decode(self.asr_model, mel, options)
+            text = asr_res.text
+            language = self.language if self.language else asr_res.language
+        else:
+            asr_res = transcribe(self.asr_model, speech, fp16=False)
+            text = asr_res["text"]
+            language = asr_res["language"]
+        results = [(text, language)]
         return results

--
Gitblit v1.9.1