zhifu gao
2023-02-15 a27a59a852141f042e95db1c6a4125483852c61f
Merge pull request #118 from alibaba-damo-academy/dev_lhn

add decoding model parameters
2 files changed, 26 lines modified:
funasr/bin/asr_inference_uniasr.py      | 14 ++++++++++++++
funasr/bin/asr_inference_uniasr_vad.py  | 12 ++++++++++++
funasr/bin/asr_inference_uniasr.py
@@ -397,7 +397,7 @@
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -439,6 +439,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
funasr/bin/asr_inference_uniasr_vad.py
@@ -439,6 +439,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,