shixian.shi
2023-08-14 c73d1a8e81582b91a9bdd6e82fce2e84f8d9d94b
funasr/bin/asr_inference_launch.py
@@ -260,8 +260,6 @@
        hotword_list_or_file = None
        clas_scale = 1.0
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
@@ -565,6 +563,8 @@
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
@@ -647,8 +647,7 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -1291,6 +1290,7 @@
        quantize_dtype: Optional[str] = "float16",
        streaming: Optional[bool] = False,
        simu_streaming: Optional[bool] = False,
        full_utt: Optional[bool] = False,
        chunk_size: Optional[int] = 16,
        left_context: Optional[int] = 16,
        right_context: Optional[int] = 0,
@@ -1340,7 +1340,7 @@
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
@@ -1367,14 +1367,12 @@
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        simu_streaming=simu_streaming,
        full_utt=full_utt,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -1420,7 +1418,7 @@
                        _end = (i + 1) * speech2text._ctx
                        speech2text.streaming_decode(
                            speech[i * speech2text._ctx: _end], is_final=False
                            speech[i * speech2text._ctx: _end + speech2text._right_ctx], is_final=False
                        )
                    final_hyps = speech2text.streaming_decode(
@@ -1428,6 +1426,8 @@
                    )
                elif speech2text.simu_streaming:
                    final_hyps = speech2text.simu_streaming_decode(**batch)
                elif speech2text.full_utt:
                    final_hyps = speech2text.full_utt_decode(**batch)
                else:
                    final_hyps = speech2text(**batch)
@@ -1816,6 +1816,7 @@
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--simu_streaming", type=str2bool, default=False)
    group.add_argument("--full_utt", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)