liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
@@ -60,14 +60,12 @@
from icefall.utils import store_transcripts, write_error_stats
# Manifest with one Python-literal dict per line (id, audio_filepath, text, ...).
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt"  # noqa
# NOTE(review): DEFAULT_ROOT is assigned twice on purpose in the original file —
# the last assignment wins, so the local "./" default acts as a quick toggle
# left in place for development. Kept as-is to preserve behavior.
DEFAULT_ROOT = "./"
DEFAULT_ROOT = "/mfs/songtao/researchcode/FunASR/data/"
def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--server-addr",
@@ -185,7 +183,7 @@
    with open(fp) as f:
        for i, dp in enumerate(f.readlines()):
            dp = eval(dp)
            dp['id'] = i
            dp["id"] = i
            data.append(dp)
    return data
@@ -195,19 +193,19 @@
    # import pdb;pdb.set_trace()
    assert len(dps) > num_tasks
    one_task_num = len(dps)//num_tasks
    one_task_num = len(dps) // num_tasks
    for i in range(0, len(dps), one_task_num):
        if i+one_task_num >= len(dps):
        if i + one_task_num >= len(dps):
            for k, j in enumerate(range(i, len(dps))):
                dps_splited[k].append(dps[j])
        else:
            dps_splited.append(dps[i:i+one_task_num])
            dps_splited.append(dps[i : i + one_task_num])
    return dps_splited
def load_audio(path):
    """Load a WAV file as mono, 16 kHz, normalized float32 samples.

    Args:
        path: Filesystem path to a WAV file readable by pydub.

    Returns:
        Tuple of (samples, duration):
          samples: np.ndarray of float32, scaled by 1/32768
                   (assumes 16-bit PCM source — TODO confirm).
          duration: length of the audio in seconds (float).
    """
    audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
    # Scale int16 PCM into roughly [-1.0, 1.0) for the model input.
    audio_np = np.array(audio.get_array_of_samples()) / 32768.0
    return audio_np.astype(np.float32), audio.duration_seconds
@@ -227,16 +225,14 @@
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(dps)}")
        waveform, duration = load_audio(
            os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
        waveform, duration = load_audio(os.path.join(DEFAULT_ROOT, dp["audio_filepath"]))
        sample_rate = 16000
        # padding to nearset 10 seconds
        samples = np.zeros(
            (
                1,
                10 * sample_rate *
                (int(len(waveform) / sample_rate // 10) + 1),
                10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1),
            ),
            dtype=np.float32,
        )
@@ -245,9 +241,7 @@
        lengths = np.array([[len(waveform)]], dtype=np.int32)
        inputs = [
            protocol_client.InferInput(
                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
            ),
            protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
            protocol_client.InferInput(
                "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
            ),
@@ -271,16 +265,16 @@
        total_duration += duration
        if compute_cer:
            ref = dp['text'].split()
            ref = dp["text"].split()
            hyp = decoding_results.split()
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results.append((dp['id'], ref, hyp))
            results.append((dp["id"], ref, hyp))
        else:
            results.append(
                (
                    dp['id'],
                    dp['text'].split(),
                    dp["id"],
                    dp["text"].split(),
                    decoding_results.split(),
                )
            )  # noqa
@@ -309,7 +303,7 @@
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(dps)}")
        waveform, duration = load_audio(dp['audio_filepath'])
        waveform, duration = load_audio(dp["audio_filepath"])
        sample_rate = 16000
        wav_segs = []
@@ -318,10 +312,10 @@
        while j < len(waveform):
            if j == 0:
                stride = int(first_chunk_in_secs * sample_rate)
                wav_segs.append(waveform[j: j + stride])
                wav_segs.append(waveform[j : j + stride])
            else:
                stride = int(other_chunk_in_secs * sample_rate)
                wav_segs.append(waveform[j: j + stride])
                wav_segs.append(waveform[j : j + stride])
            j += len(wav_segs[-1])
        sequence_id = task_index + 10086
@@ -380,25 +374,23 @@
                decoding_results = b" ".join(decoding_results).decode("utf-8")
            else:
                # For wenet
                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
                    "utf-8"
                )
                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
            chunk_end = time.time() - chunk_start
            latency_data.append((chunk_end, chunk_len / sample_rate))
        total_duration += duration
        if compute_cer:
            ref = dp['text'].split()
            ref = dp["text"].split()
            hyp = decoding_results.split()
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results.append((dp['id'], ref, hyp))
            results.append((dp["id"], ref, hyp))
        else:
            results.append(
                (
                    dp['id'],
                    dp['text'].split(),
                    dp["id"],
                    dp["text"].split(),
                    decoding_results.split(),
                )
            )  # noqa
@@ -426,15 +418,11 @@
    if args.streaming or args.simulate_streaming:
        frame_shift_ms = 10
        frame_length_ms = 25
        add_frames = math.ceil(
            (frame_length_ms - frame_shift_ms) / frame_shift_ms
        )
        add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
        # decode_window_length: input sequence length of streaming encoder
        if args.context > 0:
            # decode window length calculation for wenet
            decode_window_length = (
                args.chunk_size - 1
            ) * args.subsampling + args.context
            decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
        else:
            # decode window length calculation for icefall
            decode_window_length = (
@@ -457,10 +445,7 @@
                    compute_cer=compute_cer,
                    model_name=args.model_name,
                    first_chunk_in_secs=first_chunk_ms / 1000,
                    other_chunk_in_secs=args.chunk_size
                    * args.subsampling
                    * frame_shift_ms
                    / 1000,
                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                    task_index=i,
                )
            )
@@ -475,10 +460,7 @@
                    compute_cer=compute_cer,
                    model_name=args.model_name,
                    first_chunk_in_secs=first_chunk_ms / 1000,
                    other_chunk_in_secs=args.chunk_size
                    * args.subsampling
                    * frame_shift_ms
                    / 1000,
                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                    task_index=i,
                    simulate_mode=True,
                )
@@ -516,15 +498,10 @@
    s = f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration/3600:.2f} hours)\n"
    s += (
        f"processing time: {elapsed:.3f} seconds "
        f"({elapsed/3600:.2f} hours)\n"
    )
    s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
    if args.streaming or args.simulate_streaming:
        latency_list = [
            chunk_end for (chunk_end, chunk_duration) in latency_data
        ]
        latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
        latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
        latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
        s += f"latency_variance: {latency_variance:.2f}\n"
@@ -550,9 +527,7 @@
        print(f.readline())  # Detailed errors
    if args.stats_file:
        stats = await triton_client.get_inference_statistics(
            model_name="", as_json=True
        )
        stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
        with open(args.stats_file, "w") as f:
            json.dump(stats, f)