jmwang66
2023-05-09 8dab6d184a034ca86eafa644ea0d2100aadfe27d
funasr/bin/asr_inference.py
@@ -41,6 +41,7 @@
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -92,7 +93,11 @@
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            if asr_train_args.frontend=='wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +116,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
@@ -193,7 +198,7 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
@@ -280,6 +285,7 @@
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
@@ -310,6 +316,7 @@
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        mc=mc,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -342,6 +349,7 @@
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
@@ -355,6 +363,9 @@
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +419,7 @@
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
@@ -416,7 +428,7 @@
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
@@ -452,7 +464,7 @@
                    
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                
                if text is not None:
@@ -463,6 +475,9 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list
    
    return _forward
@@ -637,4 +652,4 @@
if __name__ == "__main__":
    # Invoke the command-line entry point exactly once when this module is
    # executed as a script; a second call would repeat the whole run.
    main()