凌匀
2023-02-16 8689fb676d6cd28894f55c8cac43409e7bf4cd38
support asr_inference_paraformer_vad_punc
2个文件已修改
1个文件已删除
412 ■■■■■ 已修改文件
funasr/bin/asr_inference_paraformer_vad_punc.py 44 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
vad_inference.py 364 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -185,11 +185,10 @@
        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4
        
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, begin_time: int = 0, end_time: int = None,
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            begin_time: int = 0, end_time: int = None,
    ):
        """Inference
@@ -229,7 +228,8 @@
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
@@ -286,11 +286,13 @@
                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
                    time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time,
                                                 end_time)
                    results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
class Speech2VadSegment:
    """Speech2VadSegment class
@@ -333,6 +335,7 @@
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.batch_size = batch_size
    @torch.no_grad()
    def __call__(
@@ -361,16 +364,30 @@
            feats_len = feats_len.int()
        else:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech}
        # b. Forward Encoder streaming
        t_offset = 0
        step = min(feats_len, 6000)
        segments = [[]] * self.batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
                is_final_send = True
            else:
                is_final_send = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final_send": is_final_send
            }
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        segments = self.vad_model(**batch)
            segments_part = self.vad_model(**batch)
            if segments_part:
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
        return fbanks, segments
def inference(
@@ -410,7 +427,6 @@
    punc_model_file: Optional[str] = None,
    **kwargs,
):
    inference_pipeline = inference_modelscope(
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
@@ -448,6 +464,7 @@
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
@@ -611,7 +628,8 @@
                    if j == 0:
                        result_segments = result_cur
                    else:
                        result_segments = [[result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
                        result_segments = [
                            [result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]
    
                key = keys[0]
                result = result_segments[0]
@@ -657,8 +675,10 @@
    
                logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list
    return _forward
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
funasr/bin/vad_inference.py
@@ -107,10 +107,8 @@
            feats_len = feats_len.int()
        else:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        # batch = {"feats": feats, "waveform": speech, "is_final_send": True}
        # segments = self.vad_model(**batch)
        # b. Forward Encoder sreaming
        # b. Forward Encoder streaming
        t_offset = 0
        step = min(feats_len, 6000)
        segments = [[]] * self.batch_size
vad_inference.py
File was deleted