游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/bin/asr_infer.py
@@ -38,9 +38,7 @@
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
from funasr.utils.whisper_utils.transcribe import transcribe
from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
class Speech2Text:
    """Speech2Text class
@@ -1607,7 +1605,6 @@
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])
        return nbest_hyps
@@ -1918,12 +1915,15 @@
            nbest: int = 1,
            streaming: bool = False,
            frontend_conf: dict = None,
            language: str = None,
            task: str = "transcribe",
            **kwargs,
    ):
        from funasr.tasks.whisper import ASRTask
        # 1. Build ASR model
        scorers = {}
        from funasr.tasks.whisper import ASRTask
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
@@ -1960,6 +1960,8 @@
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.language = language
        self.task = task
    @torch.no_grad()
    def __call__(
@@ -1981,15 +1983,19 @@
        """
        from funasr.utils.whisper_utils.transcribe import transcribe
        from funasr.utils.whisper_utils.audio import pad_or_trim, log_mel_spectrogram
        from funasr.utils.whisper_utils.decoding import DecodingOptions, detect_language, decode
        speech = speech[0]
        speech = pad_or_trim(speech)
        mel = log_mel_spectrogram(speech).to(self.device)
        if self.asr_model.is_multilingual:
            options = DecodingOptions(fp16=False)
            options = DecodingOptions(fp16=False, language=self.language, task=self.task)
            asr_res = decode(self.asr_model, mel, options)
            text = asr_res.text
            language = asr_res.language
            language = self.language if self.language else asr_res.language
        else:
            asr_res = transcribe(self.asr_model, speech, fp16=False)
            text = asr_res["text"]