From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py |   85 +++++++++++++++---------------------------
 1 files changed, 30 insertions(+), 55 deletions(-)

diff --git a/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py b/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
index ad121c6..379b28d 100644
--- a/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
+++ b/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
@@ -60,14 +60,12 @@
 from icefall.utils import store_transcripts, write_error_stats
 
 DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt"  # noqa
-DEFAULT_ROOT = './'
-DEFAULT_ROOT = '/mfs/songtao/researchcode/FunASR/data/'
+DEFAULT_ROOT = "./"
+DEFAULT_ROOT = "/mfs/songtao/researchcode/FunASR/data/"
 
 
 def get_args():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     parser.add_argument(
         "--server-addr",
@@ -185,7 +183,7 @@
     with open(fp) as f:
         for i, dp in enumerate(f.readlines()):
             dp = eval(dp)
-            dp['id'] = i
+            dp["id"] = i
             data.append(dp)
     return data
 
@@ -195,19 +193,19 @@
     # import pdb;pdb.set_trace()
     assert len(dps) > num_tasks
 
-    one_task_num = len(dps)//num_tasks
+    one_task_num = len(dps) // num_tasks
     for i in range(0, len(dps), one_task_num):
-        if i+one_task_num >= len(dps):
+        if i + one_task_num >= len(dps):
             for k, j in enumerate(range(i, len(dps))):
                 dps_splited[k].append(dps[j])
         else:
-            dps_splited.append(dps[i:i+one_task_num])
+            dps_splited.append(dps[i : i + one_task_num])
     return dps_splited
 
 
 def load_audio(path):
     audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
-    audiop_np = np.array(audio.get_array_of_samples())/32768.0
+    audiop_np = np.array(audio.get_array_of_samples()) / 32768.0
     return audiop_np.astype(np.float32), audio.duration_seconds
 
 
@@ -227,16 +225,14 @@
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(dps)}")
 
-        waveform, duration = load_audio(
-            os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
+        waveform, duration = load_audio(os.path.join(DEFAULT_ROOT, dp["audio_filepath"]))
         sample_rate = 16000
 
         # padding to nearset 10 seconds
         samples = np.zeros(
             (
                 1,
-                10 * sample_rate *
-                (int(len(waveform) / sample_rate // 10) + 1),
+                10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1),
             ),
             dtype=np.float32,
         )
@@ -245,9 +241,7 @@
         lengths = np.array([[len(waveform)]], dtype=np.int32)
 
         inputs = [
-            protocol_client.InferInput(
-                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
-            ),
+            protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
             protocol_client.InferInput(
                 "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
             ),
@@ -271,16 +265,16 @@
         total_duration += duration
 
         if compute_cer:
-            ref = dp['text'].split()
+            ref = dp["text"].split()
             hyp = decoding_results.split()
             ref = list("".join(ref))
             hyp = list("".join(hyp))
-            results.append((dp['id'], ref, hyp))
+            results.append((dp["id"], ref, hyp))
         else:
             results.append(
                 (
-                    dp['id'],
-                    dp['text'].split(),
+                    dp["id"],
+                    dp["text"].split(),
                     decoding_results.split(),
                 )
             )  # noqa
@@ -309,7 +303,7 @@
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(dps)}")
 
-        waveform, duration = load_audio(dp['audio_filepath'])
+        waveform, duration = load_audio(dp["audio_filepath"])
         sample_rate = 16000
 
         wav_segs = []
@@ -318,10 +312,10 @@
         while j < len(waveform):
             if j == 0:
                 stride = int(first_chunk_in_secs * sample_rate)
-                wav_segs.append(waveform[j: j + stride])
+                wav_segs.append(waveform[j : j + stride])
             else:
                 stride = int(other_chunk_in_secs * sample_rate)
-                wav_segs.append(waveform[j: j + stride])
+                wav_segs.append(waveform[j : j + stride])
             j += len(wav_segs[-1])
 
         sequence_id = task_index + 10086
@@ -380,25 +374,23 @@
                 decoding_results = b" ".join(decoding_results).decode("utf-8")
             else:
                 # For wenet
-                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
-                    "utf-8"
-                )
+                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
             chunk_end = time.time() - chunk_start
             latency_data.append((chunk_end, chunk_len / sample_rate))
 
         total_duration += duration
 
         if compute_cer:
-            ref = dp['text'].split()
+            ref = dp["text"].split()
             hyp = decoding_results.split()
             ref = list("".join(ref))
             hyp = list("".join(hyp))
-            results.append((dp['id'], ref, hyp))
+            results.append((dp["id"], ref, hyp))
         else:
             results.append(
                 (
-                    dp['id'],
-                    dp['text'].split(),
+                    dp["id"],
+                    dp["text"].split(),
                     decoding_results.split(),
                 )
             )  # noqa
@@ -426,15 +418,11 @@
     if args.streaming or args.simulate_streaming:
         frame_shift_ms = 10
         frame_length_ms = 25
-        add_frames = math.ceil(
-            (frame_length_ms - frame_shift_ms) / frame_shift_ms
-        )
+        add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
         # decode_window_length: input sequence length of streaming encoder
         if args.context > 0:
             # decode window length calculation for wenet
-            decode_window_length = (
-                args.chunk_size - 1
-            ) * args.subsampling + args.context
+            decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
         else:
             # decode window length calculation for icefall
             decode_window_length = (
@@ -457,10 +445,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                 )
             )
@@ -475,10 +460,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                     simulate_mode=True,
                 )
@@ -516,15 +498,10 @@
     s = f"RTF: {rtf:.4f}\n"
     s += f"total_duration: {total_duration:.3f} seconds\n"
     s += f"({total_duration/3600:.2f} hours)\n"
-    s += (
-        f"processing time: {elapsed:.3f} seconds "
-        f"({elapsed/3600:.2f} hours)\n"
-    )
+    s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
 
     if args.streaming or args.simulate_streaming:
-        latency_list = [
-            chunk_end for (chunk_end, chunk_duration) in latency_data
-        ]
+        latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
         latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
         latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
         s += f"latency_variance: {latency_variance:.2f}\n"
@@ -550,9 +527,7 @@
         print(f.readline())  # Detailed errors
 
     if args.stats_file:
-        stats = await triton_client.get_inference_statistics(
-            model_name="", as_json=True
-        )
+        stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
         with open(args.stats_file, "w") as f:
             json.dump(stats, f)
 

--
Gitblit v1.9.1