haoneng.lhn
2023-09-11 e60ac4bc991183a780fdd03d22db7d3b42df9b58
Support chunk-size selection for the chunk-hopping encoder
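2 files changed, 18 lines modified in total.
Note: the diff hunks below have lost their +/- markers. As a result, removed and added lines appear next to each other: for example, the old and new `__call__` signature lines, and the old and new `self.asr_model.encode(...)` calls.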
funasr/bin/asr_infer.py: 6 lines changed (patch | view | raw | blame | history)
funasr/bin/asr_inference_launch.py: 12 lines changed (patch | view | raw | blame | history)
funasr/bin/asr_infer.py
@@ -399,7 +399,7 @@
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
            decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
    ):
        """Inference
@@ -429,7 +429,9 @@
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if decoding_ind is None:
            decoding_ind = self.decoding_ind
        enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
funasr/bin/asr_inference_launch.py
@@ -236,6 +236,7 @@
        timestamp_infer_config: Union[Path, str] = None,
        timestamp_model_file: Union[Path, str] = None,
        param_dict: dict = None,
        decoding_ind: int = 0,
        **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
@@ -290,6 +291,7 @@
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
        clas_scale=clas_scale,
        decoding_ind=decoding_ind,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -312,6 +314,7 @@
            **kwargs,
    ):
        decoding_ind = None
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
@@ -319,6 +322,8 @@
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        if param_dict is not None and "decoding_ind" in param_dict:
            decoding_ind = param_dict["decoding_ind"]
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -365,6 +370,7 @@
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            batch["decoding_ind"] = decoding_ind
            results = speech2text(**batch)
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
@@ -1786,6 +1792,12 @@
        default=1,
        help="The batch size for inference",
    )
    group.add_argument(
        "--decoding_ind",
        type=int,
        default=0,
        help="chunk select for chunk encoder",
    )
    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")