zhifu gao
2023-03-16 d783b24ba7d8a03dabfa2139fcbf40c216e0ea3d
funasr/bin/asr_inference_uniasr_vad.py
@@ -46,7 +46,7 @@
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
@@ -398,6 +398,19 @@
    else:
        device = "cpu"
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -440,18 +453,6 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,