游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/bin/asr_infer.py
@@ -399,7 +399,7 @@
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
    ):
        """Inference
@@ -429,7 +429,9 @@
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if decoding_ind is None:
            decoding_ind = 0 if self.decoding_ind is None else self.decoding_ind
        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
@@ -1335,7 +1337,7 @@
            quantize_dtype: str = "qint8",
            nbest: int = 1,
            streaming: bool = False,
            simu_streaming: bool = False,
            fake_streaming: bool = False,
            full_utt: bool = False,
            chunk_size: int = 16,
            left_context: int = 32,
@@ -1430,7 +1432,7 @@
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.fake_streaming = fake_streaming
        self.full_utt = full_utt
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
@@ -1440,8 +1442,8 @@
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
        if not fake_streaming or chunk_size == 0:
            self.fake_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
@@ -1518,7 +1520,7 @@
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
        """Speech2Text call.
        Args:
            speech: Speech data. (S)
@@ -1603,7 +1605,6 @@
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
@@ -1878,3 +1879,126 @@
            results.append((text, text_id, token, token_int, hyp))
        return results
class Speech2TextWhisper:
    """Speech-to-text wrapper around a Whisper-style ASR model.

    The model is loaded through ``funasr.tasks.whisper.ASRTask`` and decoded
    with the helpers in ``funasr.utils.whisper_utils``.

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextWhisper("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio[None, :])  # leading axis is indexed with [0]
        [(text, language)]

    """

    def __init__(
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            language: str = None,
            task: str = "transcribe",
            **kwargs,
    ):
        """Load the ASR model and the optional text tokenizer.

        Args:
            asr_train_config: Training config passed to
                ``ASRTask.build_model_from_file``.
            asr_model_file: Model checkpoint passed to
                ``ASRTask.build_model_from_file``.
            cmvn_file: CMVN file passed to ``ASRTask.build_model_from_file``.
            token_type: Tokenizer type; falls back to
                ``asr_train_args.token_type`` when ``None``.
            bpemodel: BPE model path; falls back to
                ``asr_train_args.bpemodel`` when ``None``.
            device: Device the model is built on and the features are moved to.
            dtype: Name of the ``torch`` dtype the model is cast to.
            language: Language forwarded to ``DecodingOptions``; when falsy the
                language reported by the decoder is returned instead.
            task: Task forwarded to ``DecodingOptions``.

        The remaining parameters (``lm_train_config``, ``lm_file``,
        ``maxlenratio``, ``minlenratio``, ``batch_size``, ``beam_size``,
        ``ctc_weight``, ``lm_weight``, ``ngram_weight``, ``penalty``,
        ``nbest``, ``streaming``, ``frontend_conf`` and ``**kwargs``) are
        accepted but not read by this class; they are kept so the signature
        stays unchanged for existing callers.
        """
        from funasr.tasks.whisper import ASRTask

        # 1. Build ASR model
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        # 2. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            # A BPE tokenizer can only be built when a BPE model is available.
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        # No frontend is built here; the attribute is still set to None.
        self.frontend = None
        self.language = language
        self.task = task

    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[Tuple[str, str]]:
        """Inference

        Args:
            speech: Input speech data. Only ``speech[0]`` is decoded, so a
                leading axis is expected.
            speech_lengths: Accepted for interface compatibility; not used.

        Returns:
            A single-element list ``[(text, language)]``.

        """
        from funasr.utils.whisper_utils.transcribe import transcribe
        from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
        from funasr.utils.whisper_utils.decoding import DecodingOptions, decode

        speech = speech[0]
        speech = pad_or_trim(speech)

        if self.asr_model.is_multilingual:
            # The spectrogram is only needed on this path, so build it here.
            mel = log_mel_spectrogram(speech).to(self.device)
            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
            asr_res = decode(self.asr_model, mel, options)
            text = asr_res.text
            # Prefer the language the caller requested over the decoder's.
            language = self.language if self.language else asr_res.language
        else:
            asr_res = transcribe(self.asr_model, speech, fp16=False)
            text = asr_res["text"]
            language = asr_res["language"]

        return [(text, language)]