shixian.shi
2023-08-14 c73d1a8e81582b91a9bdd6e82fce2e84f8d9d94b
funasr/bin/asr_inference_launch.py
@@ -260,8 +260,6 @@
        hotword_list_or_file = None
        clas_scale = 1.0
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
@@ -566,6 +564,7 @@
            hotword_list_or_file = kwargs['hotword']
        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
@@ -648,7 +647,7 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < speech2vadsegment.vad_model.vad_opts.max_single_segment_time:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -1291,6 +1290,7 @@
        quantize_dtype: Optional[str] = "float16",
        streaming: Optional[bool] = False,
        simu_streaming: Optional[bool] = False,
        full_utt: Optional[bool] = False,
        chunk_size: Optional[int] = 16,
        left_context: Optional[int] = 16,
        right_context: Optional[int] = 0,
@@ -1367,6 +1367,7 @@
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        simu_streaming=simu_streaming,
        full_utt=full_utt,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
@@ -1417,7 +1418,7 @@
                        _end = (i + 1) * speech2text._ctx
                        speech2text.streaming_decode(
                            speech[i * speech2text._ctx: _end], is_final=False
                            speech[i * speech2text._ctx: _end + speech2text._right_ctx], is_final=False
                        )
                    final_hyps = speech2text.streaming_decode(
@@ -1425,6 +1426,8 @@
                    )
                elif speech2text.simu_streaming:
                    final_hyps = speech2text.simu_streaming_decode(**batch)
                elif speech2text.full_utt:
                    final_hyps = speech2text.full_utt_decode(**batch)
                else:
                    final_hyps = speech2text(**batch)
@@ -1813,6 +1816,7 @@
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--simu_streaming", type=str2bool, default=False)
    group.add_argument("--full_utt", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)