kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
runtime/triton_gpu/client/decode_manifest_triton.py
@@ -78,9 +78,7 @@
 def get_args():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument(
         "--server-addr",
@@ -225,9 +223,7 @@
         lengths = np.array([[len(waveform)]], dtype=np.int32)
         inputs = [
-            protocol_client.InferInput(
-                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
-            ),
+            protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
             protocol_client.InferInput(
                 "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
             ),
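
Note: the hunk above builds the request tensors for Triton. A minimal sketch of the same input construction, assuming the gRPC flavor of tritonclient and a dummy 16 kHz waveform; "WAV" and "WAV_LENS" are the tensor names the script uses.

    import numpy as np
    import tritonclient.grpc as protocol_client
    from tritonclient.utils import np_to_triton_dtype

    # Dummy one-second 16 kHz waveform, batched to shape (1, num_samples).
    waveform = np.zeros(16000, dtype=np.float32)
    samples = waveform.reshape(1, -1)
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    inputs = [
        protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput("WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)),
    ]
    # Each InferInput is filled from its numpy array before the request is sent.
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)
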
@@ -360,9 +356,7 @@
                 decoding_results = b" ".join(decoding_results).decode("utf-8")
             else:
                 # For wenet
-                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
-                    "utf-8"
-                )
+                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
             chunk_end = time.time() - chunk_start
             latency_data.append((chunk_end, chunk_len / sample_rate))
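
Note: chunk_start/chunk_end above time one round trip per chunk, and each latency_data entry pairs that wall-clock time with the chunk's audio duration. A small sketch of the bookkeeping, with an illustrative chunk size:

    import time

    sample_rate = 16000
    chunk_len = 8000           # illustrative: 0.5 s of audio per chunk
    latency_data = []

    chunk_start = time.time()
    time.sleep(0.01)           # stand-in for the actual inference round trip
    chunk_end = time.time() - chunk_start
    latency_data.append((chunk_end, chunk_len / sample_rate))
    print(latency_data)
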
@@ -406,15 +400,11 @@
     if args.streaming or args.simulate_streaming:
         frame_shift_ms = 10
         frame_length_ms = 25
-        add_frames = math.ceil(
-            (frame_length_ms - frame_shift_ms) / frame_shift_ms
-        )
+        add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
         # decode_window_length: input sequence length of streaming encoder
         if args.context > 0:
             # decode window length calculation for wenet
-            decode_window_length = (
-                args.chunk_size - 1
-            ) * args.subsampling + args.context
+            decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
         else:
             # decode window length calculation for icefall
             decode_window_length = (
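
Note: the hunk above sizes the streaming decode window. A worked example with illustrative values (chunk_size=16, subsampling=4, context=7, following the wenet branch):

    import math

    chunk_size, subsampling, context = 16, 4, 7   # illustrative values
    frame_shift_ms, frame_length_ms = 10, 25

    # Extra frames needed because the 25 ms analysis window overhangs the 10 ms shift.
    add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)   # -> 2

    # wenet branch: window in input frames before subsampling, plus left context.
    decode_window_length = (chunk_size - 1) * subsampling + context               # -> 67
    print(add_frames, decode_window_length)
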
@@ -437,10 +427,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                 )
             )
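
Note: other_chunk_in_secs converts the chunk size from subsampled encoder frames back to seconds of audio. With the same illustrative values:

    chunk_size, subsampling, frame_shift_ms = 16, 4, 10   # illustrative values

    # 16 encoder frames * 4 input frames each * 10 ms per input frame = 640 ms.
    other_chunk_in_secs = chunk_size * subsampling * frame_shift_ms / 1000
    print(other_chunk_in_secs)  # 0.64
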
@@ -455,10 +442,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                     simulate_mode=True,
                 )
@@ -496,15 +480,10 @@
    s = f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration/3600:.2f} hours)\n"
    s += (
        f"processing time: {elapsed:.3f} seconds "
        f"({elapsed/3600:.2f} hours)\n"
    )
    s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
    if args.streaming or args.simulate_streaming:
        latency_list = [
            chunk_end for (chunk_end, chunk_duration) in latency_data
        ]
        latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
        latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
        latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
        s += f"latency_variance: {latency_variance:.2f}\n"
@@ -530,9 +509,7 @@
         print(f.readline())  # Detailed errors
     if args.stats_file:
-        stats = await triton_client.get_inference_statistics(
-            model_name="", as_json=True
-        )
+        stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
         with open(args.stats_file, "w") as f:
             json.dump(stats, f)
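
Note: passing model_name="" asks Triton for statistics across all loaded models, and as_json=True returns a JSON-serializable dict. A minimal sketch assuming tritonclient's async gRPC client; the URL and output filename are placeholders.

    import asyncio
    import json
    import tritonclient.grpc.aio as grpcclient

    async def dump_stats(stats_file: str) -> None:
        # Empty model_name -> statistics for every loaded model.
        triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
        stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
        with open(stats_file, "w") as f:
            json.dump(stats, f)

    asyncio.run(dump_stats("stats.json"))
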