smohan-speech
2023-05-06 a73123bcfc14370b74b17084bc124f00c48613e4
funasr/bin/asr_inference.py
@@ -40,7 +40,6 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
@@ -91,8 +90,6 @@
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +108,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm
@@ -141,6 +138,13 @@
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
@@ -198,16 +202,7 @@
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # a. To device
        batch = to_device(batch, device=self.device)
@@ -355,6 +350,9 @@
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +406,7 @@
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=True,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
@@ -452,7 +451,7 @@
                    
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                
                if text is not None:
@@ -463,6 +462,9 @@
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list
    
    return _forward
@@ -637,4 +639,4 @@
if __name__ == "__main__":
    main()
    main()