From 4d0bbae6830019dc3a856754dada8ddc1416e83e Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期四, 12 十月 2023 16:19:13 +0800
Subject: [PATCH] Merge pull request #1003 from alibaba-damo-academy/dev_lzr_en

---
 funasr/bin/asr_infer.py |  118 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 117 insertions(+), 1 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 3117e5d..43da8bf 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -38,7 +38,9 @@
 from funasr.text.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
+from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
+from funasr.utils.whisper_utils.transcribe import transcribe
+from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
 
 class Speech2Text:
     """Speech2Text class
@@ -1880,3 +1882,117 @@
             results.append((text, text_id, token, token_int, hyp))
 
         return results
+
+
+class Speech2TextWhisper:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            frontend_conf: dict = None,
+            **kwargs,
+    ):
+
+        # 1. Build ASR model
+        scorers = {}
+        from funasr.tasks.whisper import ASRTask
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+        frontend = None
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        token_list = []
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.tokenizer = tokenizer
+        self.device = device
+        self.dtype = dtype
+        self.frontend = frontend
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+    ) -> List[
+        Tuple[
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, token, token_int, hyp
+
+        """
+
+        speech = speech[0]
+        speech = pad_or_trim(speech)
+        mel = log_mel_spectrogram(speech).to(self.device)
+
+        if self.asr_model.is_multilingual:
+            options = DecodingOptions(fp16=False)
+            asr_res = decode(self.asr_model, mel, options)
+            text = asr_res.text
+            language = asr_res.language
+        else:
+            asr_res = transcribe(self.asr_model, speech, fp16=False)
+            text = asr_res["text"]
+            language = asr_res["language"]
+        results = [(text, language)]
+        return results

--
Gitblit v1.9.1