From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 runtime/triton_gpu/client/decode_manifest_triton.py |   43 ++++++++++---------------------------------
 1 files changed, 10 insertions(+), 33 deletions(-)

diff --git a/runtime/triton_gpu/client/decode_manifest_triton.py b/runtime/triton_gpu/client/decode_manifest_triton.py
index 3a8d57f..482c715 100644
--- a/runtime/triton_gpu/client/decode_manifest_triton.py
+++ b/runtime/triton_gpu/client/decode_manifest_triton.py
@@ -78,9 +78,7 @@
 
 
 def get_args():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     parser.add_argument(
         "--server-addr",
@@ -225,9 +223,7 @@
         lengths = np.array([[len(waveform)]], dtype=np.int32)
 
         inputs = [
-            protocol_client.InferInput(
-                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
-            ),
+            protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
             protocol_client.InferInput(
                 "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
             ),
@@ -360,9 +356,7 @@
                 decoding_results = b" ".join(decoding_results).decode("utf-8")
             else:
                 # For wenet
-                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
-                    "utf-8"
-                )
+                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
             chunk_end = time.time() - chunk_start
             latency_data.append((chunk_end, chunk_len / sample_rate))
 
@@ -406,15 +400,11 @@
     if args.streaming or args.simulate_streaming:
         frame_shift_ms = 10
         frame_length_ms = 25
-        add_frames = math.ceil(
-            (frame_length_ms - frame_shift_ms) / frame_shift_ms
-        )
+        add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
         # decode_window_length: input sequence length of streaming encoder
         if args.context > 0:
             # decode window length calculation for wenet
-            decode_window_length = (
-                args.chunk_size - 1
-            ) * args.subsampling + args.context
+            decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
         else:
             # decode window length calculation for icefall
             decode_window_length = (
@@ -437,10 +427,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                 )
             )
@@ -455,10 +442,7 @@
                     compute_cer=compute_cer,
                     model_name=args.model_name,
                     first_chunk_in_secs=first_chunk_ms / 1000,
-                    other_chunk_in_secs=args.chunk_size
-                    * args.subsampling
-                    * frame_shift_ms
-                    / 1000,
+                    other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
                     task_index=i,
                     simulate_mode=True,
                 )
@@ -496,15 +480,10 @@
     s = f"RTF: {rtf:.4f}\n"
     s += f"total_duration: {total_duration:.3f} seconds\n"
     s += f"({total_duration/3600:.2f} hours)\n"
-    s += (
-        f"processing time: {elapsed:.3f} seconds "
-        f"({elapsed/3600:.2f} hours)\n"
-    )
+    s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
 
     if args.streaming or args.simulate_streaming:
-        latency_list = [
-            chunk_end for (chunk_end, chunk_duration) in latency_data
-        ]
+        latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
         latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
         latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
         s += f"latency_variance: {latency_variance:.2f}\n"
@@ -530,9 +509,7 @@
         print(f.readline())  # Detailed errors
 
     if args.stats_file:
-        stats = await triton_client.get_inference_statistics(
-            model_name="", as_json=True
-        )
+        stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
         with open(args.stats_file, "w") as f:
             json.dump(stats, f)
 

--
Gitblit v1.9.1