smohan-speech
2023-05-07 d76aea23d9f5daac4df7ee1985d07f7428abc719
funasr/bin/asr_inference.py
@@ -40,6 +40,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -90,6 +92,12 @@
            asr_train_config, asr_model_file, cmvn_file, device
        )
        # Optionally build a standalone frontend from the training arguments.
        # `frontend` stays None unless both a frontend name and a frontend
        # config are present on asr_train_args.
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend=='wav_frontend':
                # Only this branch also passes cmvn_file to the constructor.
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                # Any other frontend name is resolved through frontend_choices
                # and constructed from frontend_conf alone (no cmvn_file).
                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        # Log the loaded model and its training arguments.
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
@@ -197,12 +205,21 @@
        """
        assert check_argument_types()
        # Input as audio signal
        # A NumPy array is converted to a torch tensor; any other input is
        # passed through unchanged.
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        # NOTE(review): `batch` is rebuilt from feats/feats_len further down
        # before any visible read, so this assignment looks redundant --
        # confirm against the full method before removing it.
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        if self.frontend is not None:
            # Extract features with the standalone frontend, move them to the
            # configured device and cast the lengths to int.
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Clears the model's own frontend attribute on every call;
            # presumably so the model does not extract features a second
            # time -- TODO confirm.
            self.asr_model.frontend = None
        else:
            # No standalone frontend: forward the input and lengths as-is.
            feats = speech
            feats_len = speech_lengths
        # Derived from the size of the last feature axis, never below 1.
        # NOTE(review): 80 is hard-coded (presumably the base feature
        # dimension -- confirm), and lfr_factor has no use in the lines
        # visible here.
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
@@ -275,6 +292,7 @@
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
@@ -305,6 +323,7 @@
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        mc=mc,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -337,6 +356,7 @@
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
@@ -406,7 +426,7 @@
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=True,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
@@ -415,7 +435,7 @@
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop