From dc682db808eb5f425f0dbed4c5e7feb0a334955f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 23 Nov 2023 11:43:05 +0800
Subject: [PATCH] update funasr.text -> funasr.tokenizer; fix export bug

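Move build_tokenizer and TokenIDConverter from funasr.text to
funasr.tokenizer and switch the docstring examples from soundfile to
librosa. Speech2TextParaformer gains a clas_scale option that is
forwarded to cal_decoder_with_predictor for contextual (hotword)
decoding, and decoding_ind can now be overridden per call. The
transducer path renames simu_streaming to fake_streaming and adds a
full_utt_decode() method that encodes the whole utterance at once. A
new Speech2TextWhisper class wraps Whisper inference, using
DecodingOptions for multilingual checkpoints and transcribe()
otherwise.

A minimal usage sketch for the new class (config and model paths are
placeholders, not shipped assets):

    >>> import librosa
    >>> from funasr.bin.asr_infer import Speech2TextWhisper
    >>> s2t = Speech2TextWhisper("asr_config.yml", "asr.pb",
    ...                          language="en", task="transcribe")
    >>> audio, rate = librosa.load("speech.wav", sr=16000)
    >>> s2t(audio[None, :])
    [(text, language)]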
---
 funasr/bin/asr_infer.py |  216 +++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 190 insertions(+), 26 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 259a286..a1cede1 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -34,8 +34,8 @@
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
 from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 
@@ -44,9 +44,9 @@
     """Speech2Text class
 
     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav", sr=16000)
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]
 
@@ -251,9 +251,9 @@
     """Speech2Text class
 
     Examples:
-            >>> import soundfile
+            >>> import librosa
             >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
-            >>> audio, rate = soundfile.read("speech.wav")
+            >>> audio, rate = librosa.load("speech.wav", sr=16000)
             >>> speech2text(audio)
             [(text, token, token_int, hypothesis object), ...]
 
@@ -280,6 +280,7 @@
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            clas_scale: float = 1.0,
             decoding_ind: int = 0,
             **kwargs,
     ):
@@ -376,6 +377,7 @@
         # 6. [Optional] Build hotword list from str, local file or url
         self.hotword_list = None
         self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
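+        # weight of the contextual (hotword) bias applied in the CLAS decoder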
+        self.clas_scale = clas_scale
 
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -397,7 +399,7 @@
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            begin_time: int = 0, end_time: int = None,
+            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
     ):
         """Inference
 
@@ -427,7 +429,9 @@
         batch = to_device(batch, device=self.device)
 
         # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
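+        # prefer the per-call decoding_ind, falling back to the instance default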
+        if decoding_ind is None:
+            decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
+        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)
@@ -439,16 +443,20 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
-                                                                                   NeatContextualParaformer):
+        if not isinstance(self.asr_model,
+                          (ContextualParaformer, NeatContextualParaformer)):
             if self.hotword_list:
                 logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
             decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
                                                                      pre_token_length)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
-                                                                     pre_token_length, hw_list=self.hotword_list)
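+            # contextual decoder: bias decoding toward the hotword list, scaled by clas_scale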
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
+                                                                     enc_len,
+                                                                     pre_acoustic_embeds,
+                                                                     pre_token_length,
+                                                                     hw_list=self.hotword_list,
+                                                                     clas_scale=self.clas_scale)
             decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         if isinstance(self.asr_model, BiCifParaformer):
@@ -617,9 +625,9 @@
     """Speech2Text class
 
     Examples:
-            >>> import soundfile
+            >>> import librosa
             >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
-            >>> audio, rate = soundfile.read("speech.wav")
+            >>> audio, rate = librosa.load("speech.wav", sr=16000)
             >>> speech2text(audio)
             [(text, token, token_int, hypothesis object), ...]
 
@@ -868,9 +876,9 @@
     """Speech2Text class
 
     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav", sr=16000)
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]
 
@@ -1098,9 +1106,9 @@
     """Speech2Text class
 
     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav", sr=16000)
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]
 
@@ -1329,7 +1337,8 @@
             quantize_dtype: str = "qint8",
             nbest: int = 1,
             streaming: bool = False,
-            simu_streaming: bool = False,
+            fake_streaming: bool = False,
+            full_utt: bool = False,
             chunk_size: int = 16,
             left_context: int = 32,
             right_context: int = 0,
@@ -1423,7 +1432,8 @@
 
         self.beam_search = beam_search
         self.streaming = streaming
-        self.simu_streaming = simu_streaming
+        self.fake_streaming = fake_streaming
+        self.full_utt = full_utt
         self.chunk_size = max(chunk_size, 0)
         self.left_context = left_context
         self.right_context = max(right_context, 0)
@@ -1432,8 +1442,8 @@
             self.streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False
 
-        if not simu_streaming or chunk_size == 0:
-            self.simu_streaming = False
+        if not fake_streaming or chunk_size == 0:
+            self.fake_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False
 
         self.frontend = frontend
@@ -1443,6 +1453,7 @@
             self._ctx = self.asr_model.encoder.get_encoder_input_size(
                 self.window_size
             )
+            self._right_ctx = right_context
 
             self.last_chunk_length = (
                     self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
@@ -1509,7 +1520,7 @@
         return nbest_hyps
 
     @torch.no_grad()
-    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
         """Speech2Text call.
         Args:
             speech: Speech data. (S)
@@ -1540,6 +1551,37 @@
         return nbest_hyps
 
     @torch.no_grad()
+    def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+        """Speech2Text call.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        assert check_argument_types()
+
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        if self.frontend is not None:
+            speech = torch.unsqueeze(speech, axis=0)
+            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
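+        # apply feature normalization (e.g. global CMVN) when the model defines it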
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
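+        # encode the whole utterance in one pass instead of chunked streaming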
+        enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
+        nbest_hyps = self.beam_search(enc_out[0])
+
+        return nbest_hyps
+
+    @torch.no_grad()
     def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
         """Speech2Text call.
         Args:
@@ -1563,7 +1605,6 @@
         feats_lengths = to_device(feats_lengths, device=self.device)
 
         enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
         nbest_hyps = self.beam_search(enc_out[0])
 
         return nbest_hyps
@@ -1596,9 +1637,9 @@
     """Speech2Text class
 
     Examples:
-        >>> import soundfile
+        >>> import librosa
         >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
-        >>> audio, rate = soundfile.read("speech.wav")
+        >>> audio, rate = librosa.load("speech.wav", sr=16000)
         >>> speech2text(audio)
         [(text, token, token_int, hypothesis object), ...]
 
@@ -1838,3 +1879,126 @@
             results.append((text, text_id, token, token_int, hyp))
 
         return results
+
+
+class Speech2TextWhisper:
+    """Speech2Text class
+
+    Examples:
+        >>> import librosa
+        >>> speech2text = Speech2TextWhisper("asr_config.yml", "asr.pb")
+        >>> audio, rate = librosa.load("speech.wav", sr=16000)
+        >>> speech2text(audio[None, :])
+        [(text, language)]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            frontend_conf: dict = None,
+            language: str = None,
+            task: str = "transcribe",
+            **kwargs,
+    ):
+
+        from funasr.tasks.whisper import ASRTask
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        frontend = None
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        token_list = []
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.tokenizer = tokenizer
+        self.device = device
+        self.dtype = dtype
+        self.frontend = frontend
+        self.language = language
+        self.task = task
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+    ) -> List[
+        Tuple[
+            Optional[str],
+            Optional[str],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, language
+
+        """
+
+        from funasr.utils.whisper_utils.transcribe import transcribe
+        from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
+        from funasr.utils.whisper_utils.decoding import DecodingOptions, decode
+
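+        # drop the batch dimension; the whisper utils expect a 1-D waveform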
+        speech = speech[0]
+        speech = pad_or_trim(speech)
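+        # log-Mel spectrogram on the model's device, as Whisper decoding expects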
+        mel = log_mel_spectrogram(speech).to(self.device)
+
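+        # multilingual checkpoints take explicit language/task options; English-only models go through transcribe()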
+        if self.asr_model.is_multilingual:
+            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
+            asr_res = decode(self.asr_model, mel, options)
+            text = asr_res.text
+            language = self.language if self.language else asr_res.language
+        else:
+            asr_res = transcribe(self.asr_model, speech, fp16=False)
+            text = asr_res["text"]
+            language = asr_res["language"]
+        results = [(text, language)]
+        return results

--
Gitblit v1.9.1