| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import argparse |
| | | import logging |
| | |
| | | help="Pretrained model tag. If specify this option, *_train_config and " |
| | | "*_file will be overwritten", |
| | | ) |
| | | group.add_argument( |
| | | "--beam_search_config", |
| | | default={}, |
| | | help="The keyword arguments for transducer beam search.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | | group.add_argument( |
| | |
| | | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") |
| | | group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") |
| | | group.add_argument("--streaming", type=str2bool, default=False) |
| | | group.add_argument("--simu_streaming", type=str2bool, default=False) |
| | | group.add_argument("--chunk_size", type=int, default=16) |
| | | group.add_argument("--left_context", type=int, default=16) |
| | | group.add_argument("--right_context", type=int, default=0) |
| | | group.add_argument( |
| | | "--display_partial_hypotheses", |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to display partial hypotheses during chunk-by-chunk inference.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Dynamic quantization related") |
| | | group.add_argument( |
| | | "--quantize_asr_model", |
| | | type=bool, |
| | | default=False, |
| | | help="Apply dynamic quantization to ASR model.", |
| | | ) |
| | | group.add_argument( |
| | | "--quantize_modules", |
| | | nargs="*", |
| | | default=None, |
| | | help="""Module names to apply dynamic quantization on. |
| | | The module names are provided as a list, where each name is separated |
| | | by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). |
| | | Each specified name should be an attribute of 'torch.nn', e.g.: |
| | | torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", |
| | | ) |
| | | group.add_argument( |
| | | "--quantize_dtype", |
| | | type=str, |
| | | default="qint8", |
| | | choices=["float16", "qint8"], |
| | | help="Dtype for dynamic quantization.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | |
| | | elif mode == "uniasr": |
| | | from funasr.bin.asr_inference_uniasr import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "uniasr_vad": |
| | | from funasr.bin.asr_inference_uniasr_vad import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "paraformer": |
| | | from funasr.bin.asr_inference_paraformer import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "paraformer_streaming": |
| | | from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "paraformer_vad": |
| | | from funasr.bin.asr_inference_paraformer_vad import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "paraformer_punc": |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | | elif mode == "paraformer_vad_punc": |
| | | from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "vad": |
| | | from funasr.bin.vad_inference import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "mfcca": |
| | | from funasr.bin.asr_inference_mfcca import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "rnnt": |
| | | from funasr.bin.asr_inference_rnnt import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | else: |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | | |
def inference_launch_funasr(**kwargs):
    """Run FunASR-style inference with the backend selected by ``kwargs["mode"]``.

    The backend module for the chosen mode is imported lazily, so only that
    one is loaded. All keyword arguments, including ``mode`` itself, are
    forwarded unchanged to the backend entry point.

    Returns:
        Whatever the selected backend entry point returns, or ``None`` when
        ``mode`` is absent or not one of the recognised values. Both of those
        cases are only reported via ``logging.info``.
    """
    if "mode" not in kwargs:
        logging.info("Unknown decoding mode.")
        return None
    mode = kwargs["mode"]

    # Resolve the entry point first, then call it once at the bottom.
    if mode == "asr":
        from funasr.bin.asr_inference import inference as runner
    elif mode == "sa_asr":
        from funasr.bin.sa_asr_inference import inference as runner
    elif mode == "uniasr":
        from funasr.bin.asr_inference_uniasr import inference as runner
    elif mode == "paraformer":
        from funasr.bin.asr_inference_paraformer import inference as runner
    elif mode == "paraformer_vad_punc":
        from funasr.bin.asr_inference_paraformer_vad_punc import inference as runner
    elif mode == "vad":
        from funasr.bin.vad_inference import inference as runner
    elif mode == "mfcca":
        # NOTE(review): every other mode in this launcher calls `inference`,
        # but this one calls the ModelScope entry point. Confirm this is
        # intended, since that function may only build a pipeline rather
        # than run it.
        from funasr.bin.asr_inference_mfcca import inference_modelscope as runner
    elif mode == "rnnt":
        from funasr.bin.asr_inference_rnnt import inference as runner
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    return runner(**kwargs)
| | |
| | | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| | | os.environ["CUDA_VISIBLE_DEVICES"] = gpuid |
| | | |
| | | inference_launch(**kwargs) |
| | | inference_launch_funasr(**kwargs) |
| | | |
| | | |
# Script entry point: run the launcher exactly once. Invoking main() a second
# time would re-parse the arguments and repeat the whole inference run.
if __name__ == "__main__":
    main()